1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MLIRServer.h"
10 #include "lsp/Logging.h"
11 #include "lsp/Protocol.h"
12 #include "mlir/IR/FunctionInterfaces.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Parser/AsmParserState.h"
15 #include "mlir/Parser/Parser.h"
16 #include "llvm/Support/SourceMgr.h"
17 
18 using namespace mlir;
19 
20 /// Returns a language server position for the given source location.
21 static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, SMLoc loc) {
22   std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc);
23   lsp::Position pos;
24   pos.line = lineAndCol.first - 1;
25   pos.character = lineAndCol.second - 1;
26   return pos;
27 }
28 
29 /// Returns a source location from the given language server position.
30 static SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) {
31   return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), pos.line + 1,
32                                      pos.character);
33 }
34 
35 /// Returns a language server range for the given source range.
36 static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, SMRange range) {
37   return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)};
38 }
39 
40 /// Returns a language server location from the given source range.
41 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr,
42                                         SMRange range,
43                                         const lsp::URIForFile &uri) {
44   return lsp::Location{uri, getRangeFromLoc(mgr, range)};
45 }
46 
47 /// Returns a language server location from the given MLIR file location.
48 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) {
49   llvm::Expected<lsp::URIForFile> sourceURI =
50       lsp::URIForFile::fromFile(loc.getFilename());
51   if (!sourceURI) {
52     lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
53                        loc.getFilename(),
54                        llvm::toString(sourceURI.takeError()));
55     return llvm::None;
56   }
57 
58   lsp::Position position;
59   position.line = loc.getLine() - 1;
60   position.character = loc.getColumn();
61   return lsp::Location{*sourceURI, lsp::Range(position)};
62 }
63 
64 /// Returns a language server location from the given MLIR location, or None if
65 /// one couldn't be created. `uri` is an optional additional filter that, when
66 /// present, is used to filter sub locations that do not share the same uri.
67 static Optional<lsp::Location>
68 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
69                    const lsp::URIForFile *uri = nullptr) {
70   Optional<lsp::Location> location;
71   loc->walk([&](Location nestedLoc) {
72     FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
73     if (!fileLoc)
74       return WalkResult::advance();
75 
76     Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
77     if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
78       location = *sourceLoc;
79       SMLoc loc = sourceMgr.FindLocForLineAndColumn(
80           sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
81 
82       // Use range of potential identifier starting at location, else length 1
83       // range.
84       location->range.end.character += 1;
85       if (Optional<SMRange> range =
86               AsmParserState::convertIdLocToRange(loc)) {
87         auto lineCol = sourceMgr.getLineAndColumn(range->End);
88         location->range.end.character =
89             std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
90       }
91       return WalkResult::interrupt();
92     }
93     return WalkResult::advance();
94   });
95   return location;
96 }
97 
98 /// Collect all of the locations from the given MLIR location that are not
99 /// contained within the given URI.
100 static void collectLocationsFromLoc(Location loc,
101                                     std::vector<lsp::Location> &locations,
102                                     const lsp::URIForFile &uri) {
103   SetVector<Location> visitedLocs;
104   loc->walk([&](Location nestedLoc) {
105     FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
106     if (!fileLoc || !visitedLocs.insert(nestedLoc))
107       return WalkResult::advance();
108 
109     Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
110     if (sourceLoc && sourceLoc->uri != uri)
111       locations.push_back(*sourceLoc);
112     return WalkResult::advance();
113   });
114 }
115 
116 /// Returns true if the given range contains the given source location. Note
117 /// that this has slightly different behavior than SMRange because it is
118 /// inclusive of the end location.
119 static bool contains(SMRange range, SMLoc loc) {
120   return range.Start.getPointer() <= loc.getPointer() &&
121          loc.getPointer() <= range.End.getPointer();
122 }
123 
124 /// Returns true if the given location is contained by the definition or one of
125 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
126 /// the range within `def` that the provided `loc` overlapped with.
127 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
128                        SMRange *overlappedRange = nullptr) {
129   // Check the main definition.
130   if (contains(def.loc, loc)) {
131     if (overlappedRange)
132       *overlappedRange = def.loc;
133     return true;
134   }
135 
136   // Check the uses.
137   const auto *useIt = llvm::find_if(def.uses, [&](const SMRange &range) {
138     return contains(range, loc);
139   });
140   if (useIt != def.uses.end()) {
141     if (overlappedRange)
142       *overlappedRange = *useIt;
143     return true;
144   }
145   return false;
146 }
147 
148 /// Given a location pointing to a result, return the result number it refers
149 /// to or None if it refers to all of the results.
150 static Optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
151   // Skip all of the identifier characters.
152   auto isIdentifierChar = [](char c) {
153     return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
154            c == '-';
155   };
156   const char *curPtr = loc.getPointer();
157   while (isIdentifierChar(*curPtr))
158     ++curPtr;
159 
160   // Check to see if this location indexes into the result group, via `#`. If it
161   // doesn't, we can't extract a sub result number.
162   if (*curPtr != '#')
163     return llvm::None;
164 
165   // Compute the sub result number from the remaining portion of the string.
166   const char *numberStart = ++curPtr;
167   while (llvm::isDigit(*curPtr))
168     ++curPtr;
169   StringRef numberStr(numberStart, curPtr - numberStart);
170   unsigned resultNumber = 0;
171   return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>()
172                                                     : resultNumber;
173 }
174 
175 /// Given a source location range, return the text covered by the given range.
176 /// If the range is invalid, returns None.
177 static Optional<StringRef> getTextFromRange(SMRange range) {
178   if (!range.isValid())
179     return None;
180   const char *startPtr = range.Start.getPointer();
181   return StringRef(startPtr, range.End.getPointer() - startPtr);
182 }
183 
184 /// Given a block, return its position in its parent region.
185 static unsigned getBlockNumber(Block *block) {
186   return std::distance(block->getParent()->begin(), block->getIterator());
187 }
188 
189 /// Given a block and source location, print the source name of the block to the
190 /// given output stream.
191 static void printDefBlockName(raw_ostream &os, Block *block,
192                               SMRange loc = {}) {
193   // Try to extract a name from the source location.
194   Optional<StringRef> text = getTextFromRange(loc);
195   if (text && text->startswith("^")) {
196     os << *text;
197     return;
198   }
199 
200   // Otherwise, we don't have a name so print the block number.
201   os << "<Block #" << getBlockNumber(block) << ">";
202 }
203 static void printDefBlockName(raw_ostream &os,
204                               const AsmParserState::BlockDefinition &def) {
205   printDefBlockName(os, def.block, def.definition.loc);
206 }
207 
208 /// Convert the given MLIR diagnostic to the LSP form.
209 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
210                                                Diagnostic &diag,
211                                                const lsp::URIForFile &uri) {
212   lsp::Diagnostic lspDiag;
213   lspDiag.source = "mlir";
214 
215   // Note: Right now all of the diagnostics are treated as parser issues, but
216   // some are parser and some are verifier.
217   lspDiag.category = "Parse Error";
218 
219   // Try to grab a file location for this diagnostic.
220   // TODO: For simplicity, we just grab the first one. It may be likely that we
221   // will need a more interesting heuristic here.'
222   Optional<lsp::Location> lspLocation =
223       getLocationFromLoc(sourceMgr, diag.getLocation(), &uri);
224   if (lspLocation)
225     lspDiag.range = lspLocation->range;
226 
227   // Convert the severity for the diagnostic.
228   switch (diag.getSeverity()) {
229   case DiagnosticSeverity::Note:
230     llvm_unreachable("expected notes to be handled separately");
231   case DiagnosticSeverity::Warning:
232     lspDiag.severity = lsp::DiagnosticSeverity::Warning;
233     break;
234   case DiagnosticSeverity::Error:
235     lspDiag.severity = lsp::DiagnosticSeverity::Error;
236     break;
237   case DiagnosticSeverity::Remark:
238     lspDiag.severity = lsp::DiagnosticSeverity::Information;
239     break;
240   }
241   lspDiag.message = diag.str();
242 
243   // Attach any notes to the main diagnostic as related information.
244   std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
245   for (Diagnostic &note : diag.getNotes()) {
246     lsp::Location noteLoc;
247     if (Optional<lsp::Location> loc =
248             getLocationFromLoc(sourceMgr, note.getLocation()))
249       noteLoc = *loc;
250     else
251       noteLoc.uri = uri;
252     relatedDiags.emplace_back(noteLoc, note.str());
253   }
254   if (!relatedDiags.empty())
255     lspDiag.relatedInformation = std::move(relatedDiags);
256 
257   return lspDiag;
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // MLIRDocument
262 //===----------------------------------------------------------------------===//
263 
264 namespace {
/// This class represents all of the information pertaining to a specific MLIR
/// document: the parsed IR, the parser state recorded while parsing it, and
/// the source manager holding its text. The class is not copyable.
struct MLIRDocument {
  /// Parses `contents` as the text of the file identified by `uri`. Any
  /// diagnostic emitted while parsing is converted to its language server form
  /// and appended to `diagnostics`. If parsing fails, the parsed IR and the
  /// parser state are left empty.
  MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
               StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
  MLIRDocument(const MLIRDocument &) = delete;
  MLIRDocument &operator=(const MLIRDocument &) = delete;

  //===--------------------------------------------------------------------===//
  // Definitions and References
  //===--------------------------------------------------------------------===//

  /// Appends to `locations` the locations associated with the operation,
  /// result, symbol use, block, or block argument found at `defPos`, if any.
  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
                      std::vector<lsp::Location> &locations);
  /// Appends to `references` the definition and use locations of the entity
  /// found at `pos`, if any.
  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
                        std::vector<lsp::Location> &references);

  //===--------------------------------------------------------------------===//
  // Hover
  //===--------------------------------------------------------------------===//

  /// Returns hover information for the operation, result group, block, or
  /// block argument found at `hoverPos`, or None if nothing is found there.
  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 const lsp::Position &hoverPos);
  /// Builds a hover covering `hoverRange` that shows the name of the given
  /// operation and its generic form.
  Optional<lsp::Hover>
  buildHoverForOperation(SMRange hoverRange,
                         const AsmParserState::OperationDefinition &op);
  /// Builds a hover covering `hoverRange` that describes the results of `op`
  /// in the half-open index range [`resultStart`, `resultEnd`). `posLoc` is
  /// the hovered source location, used to detect a reference to one specific
  /// result of the group.
  lsp::Hover buildHoverForOperationResult(SMRange hoverRange,
                                          Operation *op, unsigned resultStart,
                                          unsigned resultEnd,
                                          SMLoc posLoc);
  /// Builds a hover covering `hoverRange` that describes the given block: its
  /// parent operation, block number, predecessors, and successors.
  lsp::Hover buildHoverForBlock(SMRange hoverRange,
                                const AsmParserState::BlockDefinition &block);
  /// Builds a hover covering `hoverRange` that describes the block argument
  /// `arg` of the given block.
  lsp::Hover
  buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
                             const AsmParserState::BlockDefinition &block);

  //===--------------------------------------------------------------------===//
  // Document Symbols
  //===--------------------------------------------------------------------===//

  /// Appends to `symbols` the document symbols found in the parsed IR.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
  /// Appends to `symbols` the document symbols found in `op` and, recursively,
  /// in the operations nested within its regions.
  void findDocumentSymbols(Operation *op,
                           std::vector<lsp::DocumentSymbol> &symbols);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The high level parser state used to find definitions and references within
  /// the source file.
  AsmParserState asmState;

  /// The container for the IR parsed from the input file.
  Block parsedIR;

  /// The source manager containing the contents of the input file.
  llvm::SourceMgr sourceMgr;
};
323 } // namespace
324 
325 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
326                            StringRef contents,
327                            std::vector<lsp::Diagnostic> &diagnostics) {
328   ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
329     diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
330   });
331 
332   // Try to parsed the given IR string.
333   auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
334   if (!memBuffer) {
335     lsp::Logger::error("Failed to create memory buffer for file", uri.file());
336     return;
337   }
338 
339   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
340   if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr,
341                              &asmState))) {
342     // If parsing failed, clear out any of the current state.
343     parsedIR.clear();
344     asmState = AsmParserState();
345     return;
346   }
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // MLIRDocument: Definitions and References
351 //===----------------------------------------------------------------------===//
352 
353 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
354                                   const lsp::Position &defPos,
355                                   std::vector<lsp::Location> &locations) {
356   SMLoc posLoc = getPosFromLoc(sourceMgr, defPos);
357 
358   // Functor used to check if an SM definition contains the position.
359   auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
360     if (!isDefOrUse(def, posLoc))
361       return false;
362     locations.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
363     return true;
364   };
365 
366   // Check all definitions related to operations.
367   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
368     if (contains(op.loc, posLoc))
369       return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
370     for (const auto &result : op.resultGroups)
371       if (containsPosition(result.definition))
372         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
373     for (const auto &symUse : op.symbolUses) {
374       if (contains(symUse, posLoc)) {
375         locations.push_back(getLocationFromLoc(sourceMgr, op.loc, uri));
376         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
377       }
378     }
379   }
380 
381   // Check all definitions related to blocks.
382   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
383     if (containsPosition(block.definition))
384       return;
385     for (const AsmParserState::SMDefinition &arg : block.arguments)
386       if (containsPosition(arg))
387         return;
388   }
389 }
390 
391 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
392                                     const lsp::Position &pos,
393                                     std::vector<lsp::Location> &references) {
394   // Functor used to append all of the definitions/uses of the given SM
395   // definition to the reference list.
396   auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
397     references.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
398     for (const SMRange &use : def.uses)
399       references.push_back(getLocationFromLoc(sourceMgr, use, uri));
400   };
401 
402   SMLoc posLoc = getPosFromLoc(sourceMgr, pos);
403 
404   // Check all definitions related to operations.
405   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
406     if (contains(op.loc, posLoc)) {
407       for (const auto &result : op.resultGroups)
408         appendSMDef(result.definition);
409       for (const auto &symUse : op.symbolUses)
410         if (contains(symUse, posLoc))
411           references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
412       return;
413     }
414     for (const auto &result : op.resultGroups)
415       if (isDefOrUse(result.definition, posLoc))
416         return appendSMDef(result.definition);
417     for (const auto &symUse : op.symbolUses) {
418       if (!contains(symUse, posLoc))
419         continue;
420       for (const auto &symUse : op.symbolUses)
421         references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
422       return;
423     }
424   }
425 
426   // Check all definitions related to blocks.
427   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
428     if (isDefOrUse(block.definition, posLoc))
429       return appendSMDef(block.definition);
430 
431     for (const AsmParserState::SMDefinition &arg : block.arguments)
432       if (isDefOrUse(arg, posLoc))
433         return appendSMDef(arg);
434   }
435 }
436 
437 //===----------------------------------------------------------------------===//
438 // MLIRDocument: Hover
439 //===----------------------------------------------------------------------===//
440 
441 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
442                                              const lsp::Position &hoverPos) {
443   SMLoc posLoc = getPosFromLoc(sourceMgr, hoverPos);
444   SMRange hoverRange;
445 
446   // Check for Hovers on operations and results.
447   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
448     // Check if the position points at this operation.
449     if (contains(op.loc, posLoc))
450       return buildHoverForOperation(op.loc, op);
451 
452     // Check if the position points at the symbol name.
453     for (auto &use : op.symbolUses)
454       if (contains(use, posLoc))
455         return buildHoverForOperation(use, op);
456 
457     // Check if the position points at a result group.
458     for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
459       const auto &result = op.resultGroups[i];
460       if (!isDefOrUse(result.definition, posLoc, &hoverRange))
461         continue;
462 
463       // Get the range of results covered by the over position.
464       unsigned resultStart = result.startIndex;
465       unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
466                                         : op.resultGroups[i + 1].startIndex;
467       return buildHoverForOperationResult(hoverRange, op.op, resultStart,
468                                           resultEnd, posLoc);
469     }
470   }
471 
472   // Check to see if the hover is over a block argument.
473   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
474     if (isDefOrUse(block.definition, posLoc, &hoverRange))
475       return buildHoverForBlock(hoverRange, block);
476 
477     for (const auto &arg : llvm::enumerate(block.arguments)) {
478       if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
479         continue;
480 
481       return buildHoverForBlockArgument(
482           hoverRange, block.block->getArgument(arg.index()), block);
483     }
484   }
485   return llvm::None;
486 }
487 
488 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
489     SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
490   lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
491   llvm::raw_string_ostream os(hover.contents.value);
492 
493   // Add the operation name to the hover.
494   os << "\"" << op.op->getName() << "\"";
495   if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
496     os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
497   os << "\n\n";
498 
499   os << "Generic Form:\n\n```mlir\n";
500 
501   // Temporary drop the regions of this operation so that they don't get
502   // printed in the output. This helps keeps the size of the output hover
503   // small.
504   SmallVector<std::unique_ptr<Region>> regions;
505   for (Region &region : op.op->getRegions()) {
506     regions.emplace_back(std::make_unique<Region>());
507     regions.back()->takeBody(region);
508   }
509 
510   op.op->print(
511       os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
512   os << "\n```\n";
513 
514   // Move the regions back to the current operation.
515   for (Region &region : op.op->getRegions())
516     region.takeBody(*regions.back());
517 
518   return hover;
519 }
520 
521 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
522                                                       Operation *op,
523                                                       unsigned resultStart,
524                                                       unsigned resultEnd,
525                                                       SMLoc posLoc) {
526   lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
527   llvm::raw_string_ostream os(hover.contents.value);
528 
529   // Add the parent operation name to the hover.
530   os << "Operation: \"" << op->getName() << "\"\n\n";
531 
532   // Check to see if the location points to a specific result within the
533   // group.
534   if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
535     if ((resultStart + *resultNumber) < resultEnd) {
536       resultStart += *resultNumber;
537       resultEnd = resultStart + 1;
538     }
539   }
540 
541   // Add the range of results and their types to the hover info.
542   if ((resultStart + 1) == resultEnd) {
543     os << "Result #" << resultStart << "\n\n"
544        << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
545   } else {
546     os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
547        << "Types: ";
548     llvm::interleaveComma(
549         op->getResults().slice(resultStart, resultEnd), os,
550         [&](Value result) { os << "`" << result.getType() << "`"; });
551   }
552 
553   return hover;
554 }
555 
556 lsp::Hover
557 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
558                                  const AsmParserState::BlockDefinition &block) {
559   lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
560   llvm::raw_string_ostream os(hover.contents.value);
561 
562   // Print the given block to the hover output stream.
563   auto printBlockToHover = [&](Block *newBlock) {
564     if (const auto *def = asmState.getBlockDef(newBlock))
565       printDefBlockName(os, *def);
566     else
567       printDefBlockName(os, newBlock);
568   };
569 
570   // Display the parent operation, block number, predecessors, and successors.
571   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
572      << "Block #" << getBlockNumber(block.block) << "\n\n";
573   if (!block.block->hasNoPredecessors()) {
574     os << "Predecessors: ";
575     llvm::interleaveComma(block.block->getPredecessors(), os,
576                           printBlockToHover);
577     os << "\n\n";
578   }
579   if (!block.block->hasNoSuccessors()) {
580     os << "Successors: ";
581     llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
582     os << "\n\n";
583   }
584 
585   return hover;
586 }
587 
588 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
589     SMRange hoverRange, BlockArgument arg,
590     const AsmParserState::BlockDefinition &block) {
591   lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
592   llvm::raw_string_ostream os(hover.contents.value);
593 
594   // Display the parent operation, block, the argument number, and the type.
595   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
596      << "Block: ";
597   printDefBlockName(os, block);
598   os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
599      << "Type: `" << arg.getType() << "`\n\n";
600 
601   return hover;
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // MLIRDocument: Document Symbols
606 //===----------------------------------------------------------------------===//
607 
608 void MLIRDocument::findDocumentSymbols(
609     std::vector<lsp::DocumentSymbol> &symbols) {
610   for (Operation &op : parsedIR)
611     findDocumentSymbols(&op, symbols);
612 }
613 
614 void MLIRDocument::findDocumentSymbols(
615     Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
616   std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
617 
618   // Check for the source information of this operation.
619   if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
620     // If this operation defines a symbol, record it.
621     if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
622       symbols.emplace_back(symbol.getName(),
623                            isa<FunctionOpInterface>(op)
624                                ? lsp::SymbolKind::Function
625                                : lsp::SymbolKind::Class,
626                            getRangeFromLoc(sourceMgr, def->scopeLoc),
627                            getRangeFromLoc(sourceMgr, def->loc));
628       childSymbols = &symbols.back().children;
629 
630     } else if (op->hasTrait<OpTrait::SymbolTable>()) {
631       // Otherwise, if this is a symbol table push an anonymous document symbol.
632       symbols.emplace_back("<" + op->getName().getStringRef() + ">",
633                            lsp::SymbolKind::Namespace,
634                            getRangeFromLoc(sourceMgr, def->scopeLoc),
635                            getRangeFromLoc(sourceMgr, def->loc));
636       childSymbols = &symbols.back().children;
637     }
638   }
639 
640   // Recurse into the regions of this operation.
641   if (!op->getNumRegions())
642     return;
643   for (Region &region : op->getRegions())
644     for (Operation &childOp : region.getOps())
645       findDocumentSymbols(&childOp, *childSymbols);
646 }
647 
648 //===----------------------------------------------------------------------===//
649 // MLIRTextFileChunk
650 //===----------------------------------------------------------------------===//
651 
652 namespace {
653 /// This class represents a single chunk of an MLIR text file.
654 struct MLIRTextFileChunk {
655   MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
656                     const lsp::URIForFile &uri, StringRef contents,
657                     std::vector<lsp::Diagnostic> &diagnostics)
658       : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
659 
660   /// Adjust the line number of the given range to anchor at the beginning of
661   /// the file, instead of the beginning of this chunk.
662   void adjustLocForChunkOffset(lsp::Range &range) {
663     adjustLocForChunkOffset(range.start);
664     adjustLocForChunkOffset(range.end);
665   }
666   /// Adjust the line number of the given position to anchor at the beginning of
667   /// the file, instead of the beginning of this chunk.
668   void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
669 
670   /// The line offset of this chunk from the beginning of the file.
671   uint64_t lineOffset;
672   /// The document referred to by this chunk.
673   MLIRDocument document;
674 };
675 } // namespace
676 
677 //===----------------------------------------------------------------------===//
678 // MLIRTextFile
679 //===----------------------------------------------------------------------===//
680 
681 namespace {
/// This class represents a text file containing one or more MLIR documents.
/// The file contents are split into chunks, each of which is parsed as its own
/// MLIR document, and position-based queries are dispatched to the chunk
/// containing the queried position.
class MLIRTextFile {
public:
  /// Splits `fileContents` into chunks and parses each of them. Diagnostics
  /// produced while parsing are appended to `diagnostics`, with the locations
  /// that refer to `uri` expressed relative to the beginning of the file.
  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
               int64_t version, DialectRegistry &registry,
               std::vector<lsp::Diagnostic> &diagnostics);

  /// Return the current version of this text file.
  int64_t getVersion() const { return version; }

  //===--------------------------------------------------------------------===//
  // LSP Queries
  //===--------------------------------------------------------------------===//

  /// Appends to `locations` the locations reported by the chunk containing
  /// `defPos`. Locations within `uri` are expressed relative to the beginning
  /// of the file.
  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
                      std::vector<lsp::Location> &locations);
  /// Appends to `references` the references reported by the chunk containing
  /// `pos`.
  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
                        std::vector<lsp::Location> &references);
  /// Hover query for the given position in the file.
  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 lsp::Position hoverPos);
  /// Document symbol query for the file; results are appended to `symbols`.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);

private:
  /// Find the MLIR document that contains the given position, and update the
  /// position to be anchored at the start of the found chunk instead of the
  /// beginning of the file.
  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);

  /// The context used to hold the state contained by the parsed document.
  MLIRContext context;

  /// The full string contents of the file.
  std::string contents;

  /// The version of this file.
  int64_t version;

  /// The number of lines in the file. This is computed by the constructor as
  /// the number of newline characters found across all chunks.
  int64_t totalNumLines = 0;

  /// The chunks of this file. The order of these chunks is the order in which
  /// they appear in the text file.
  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
};
726 } // namespace
727 
728 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
729                            int64_t version, DialectRegistry &registry,
730                            std::vector<lsp::Diagnostic> &diagnostics)
731     : context(registry, MLIRContext::Threading::DISABLED),
732       contents(fileContents.str()), version(version) {
733   context.allowUnregisteredDialects();
734 
735   // Split the file into separate MLIR documents.
736   // TODO: Find a way to share the split file marker with other tools. We don't
737   // want to use `splitAndProcessBuffer` here, but we do want to make sure this
738   // marker doesn't go out of sync.
739   SmallVector<StringRef, 8> subContents;
740   StringRef(contents).split(subContents, "// -----");
741   chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
742       context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
743 
744   uint64_t lineOffset = subContents.front().count('\n');
745   for (StringRef docContents : llvm::drop_begin(subContents)) {
746     unsigned currentNumDiags = diagnostics.size();
747     auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
748                                                      docContents, diagnostics);
749     lineOffset += docContents.count('\n');
750 
751     // Adjust locations used in diagnostics to account for the offset from the
752     // beginning of the file.
753     for (lsp::Diagnostic &diag :
754          llvm::drop_begin(diagnostics, currentNumDiags)) {
755       chunk->adjustLocForChunkOffset(diag.range);
756 
757       if (!diag.relatedInformation)
758         continue;
759       for (auto &it : *diag.relatedInformation)
760         if (it.location.uri == uri)
761           chunk->adjustLocForChunkOffset(it.location.range);
762     }
763     chunks.emplace_back(std::move(chunk));
764   }
765   totalNumLines = lineOffset;
766 }
767 
768 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
769                                   lsp::Position defPos,
770                                   std::vector<lsp::Location> &locations) {
771   MLIRTextFileChunk &chunk = getChunkFor(defPos);
772   chunk.document.getLocationsOf(uri, defPos, locations);
773 
774   // Adjust any locations within this file for the offset of this chunk.
775   if (chunk.lineOffset == 0)
776     return;
777   for (lsp::Location &loc : locations)
778     if (loc.uri == uri)
779       chunk.adjustLocForChunkOffset(loc.range);
780 }
781 
782 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
783                                     lsp::Position pos,
784                                     std::vector<lsp::Location> &references) {
785   MLIRTextFileChunk &chunk = getChunkFor(pos);
786   chunk.document.findReferencesOf(uri, pos, references);
787 
788   // Adjust any locations within this file for the offset of this chunk.
789   if (chunk.lineOffset == 0)
790     return;
791   for (lsp::Location &loc : references)
792     if (loc.uri == uri)
793       chunk.adjustLocForChunkOffset(loc.range);
794 }
795 
796 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
797                                              lsp::Position hoverPos) {
798   MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
799   Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
800 
801   // Adjust any locations within this file for the offset of this chunk.
802   if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
803     chunk.adjustLocForChunkOffset(*hoverInfo->range);
804   return hoverInfo;
805 }
806 
807 void MLIRTextFile::findDocumentSymbols(
808     std::vector<lsp::DocumentSymbol> &symbols) {
809   if (chunks.size() == 1)
810     return chunks.front()->document.findDocumentSymbols(symbols);
811 
812   // If there are multiple chunks in this file, we create top-level symbols for
813   // each chunk.
814   for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
815     MLIRTextFileChunk &chunk = *chunks[i];
816     lsp::Position startPos(chunk.lineOffset);
817     lsp::Position endPos((i == e - 1) ? totalNumLines - 1
818                                       : chunks[i + 1]->lineOffset);
819     lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
820                                lsp::SymbolKind::Namespace,
821                                /*range=*/lsp::Range(startPos, endPos),
822                                /*selectionRange=*/lsp::Range(startPos));
823     chunk.document.findDocumentSymbols(symbol.children);
824 
825     // Fixup the locations of document symbols within this chunk.
826     if (i != 0) {
827       SmallVector<lsp::DocumentSymbol *> symbolsToFix;
828       for (lsp::DocumentSymbol &childSymbol : symbol.children)
829         symbolsToFix.push_back(&childSymbol);
830 
831       while (!symbolsToFix.empty()) {
832         lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
833         chunk.adjustLocForChunkOffset(symbol->range);
834         chunk.adjustLocForChunkOffset(symbol->selectionRange);
835 
836         for (lsp::DocumentSymbol &childSymbol : symbol->children)
837           symbolsToFix.push_back(&childSymbol);
838       }
839     }
840 
841     // Push the symbol for this chunk.
842     symbols.emplace_back(std::move(symbol));
843   }
844 }
845 
846 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
847   if (chunks.size() == 1)
848     return *chunks.front();
849 
850   // Search for the first chunk with a greater line offset, the previous chunk
851   // is the one that contains `pos`.
852   auto it = llvm::upper_bound(
853       chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
854         return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
855       });
856   MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
857   pos.line -= chunk.lineOffset;
858   return chunk;
859 }
860 
861 //===----------------------------------------------------------------------===//
862 // MLIRServer::Impl
863 //===----------------------------------------------------------------------===//
864 
865 struct lsp::MLIRServer::Impl {
866   Impl(DialectRegistry &registry) : registry(registry) {}
867 
868   /// The registry containing dialects that can be recognized in parsed .mlir
869   /// files.
870   DialectRegistry &registry;
871 
872   /// The files held by the server, mapped by their URI file name.
873   llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
874 };
875 
876 //===----------------------------------------------------------------------===//
877 // MLIRServer
878 //===----------------------------------------------------------------------===//
879 
/// Create a server whose implementation state keeps a reference to `registry`;
/// the registry is later passed to every text file the server creates.
lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
    : impl(std::make_unique<Impl>(registry)) {}
/// Defaulted out of line in this file, where `Impl` is fully defined
/// (presumably required to destroy the `impl` member -- confirm against the
/// header).
lsp::MLIRServer::~MLIRServer() = default;
883 
884 void lsp::MLIRServer::addOrUpdateDocument(
885     const URIForFile &uri, StringRef contents, int64_t version,
886     std::vector<Diagnostic> &diagnostics) {
887   impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
888       uri, contents, version, impl->registry, diagnostics);
889 }
890 
891 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
892   auto it = impl->files.find(uri.file());
893   if (it == impl->files.end())
894     return llvm::None;
895 
896   int64_t version = it->second->getVersion();
897   impl->files.erase(it);
898   return version;
899 }
900 
901 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
902                                      const Position &defPos,
903                                      std::vector<Location> &locations) {
904   auto fileIt = impl->files.find(uri.file());
905   if (fileIt != impl->files.end())
906     fileIt->second->getLocationsOf(uri, defPos, locations);
907 }
908 
909 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
910                                        const Position &pos,
911                                        std::vector<Location> &references) {
912   auto fileIt = impl->files.find(uri.file());
913   if (fileIt != impl->files.end())
914     fileIt->second->findReferencesOf(uri, pos, references);
915 }
916 
917 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
918                                                 const Position &hoverPos) {
919   auto fileIt = impl->files.find(uri.file());
920   if (fileIt != impl->files.end())
921     return fileIt->second->findHover(uri, hoverPos);
922   return llvm::None;
923 }
924 
925 void lsp::MLIRServer::findDocumentSymbols(
926     const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
927   auto fileIt = impl->files.find(uri.file());
928   if (fileIt != impl->files.end())
929     fileIt->second->findDocumentSymbols(symbols);
930 }
931