1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MLIRServer.h"
10 #include "../lsp-server-support/Logging.h"
11 #include "../lsp-server-support/Protocol.h"
12 #include "../lsp-server-support/SourceMgrUtils.h"
13 #include "mlir/IR/FunctionInterfaces.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/Parser/AsmParserState.h"
16 #include "mlir/Parser/CodeComplete.h"
17 #include "mlir/Parser/Parser.h"
18 #include "llvm/Support/SourceMgr.h"
19 
20 using namespace mlir;
21 
22 /// Returns a language server location from the given MLIR file location.
23 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) {
24   llvm::Expected<lsp::URIForFile> sourceURI =
25       lsp::URIForFile::fromFile(loc.getFilename());
26   if (!sourceURI) {
27     lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
28                        loc.getFilename(),
29                        llvm::toString(sourceURI.takeError()));
30     return llvm::None;
31   }
32 
33   lsp::Position position;
34   position.line = loc.getLine() - 1;
35   position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
36   return lsp::Location{*sourceURI, lsp::Range(position)};
37 }
38 
39 /// Returns a language server location from the given MLIR location, or None if
40 /// one couldn't be created. `uri` is an optional additional filter that, when
41 /// present, is used to filter sub locations that do not share the same uri.
42 static Optional<lsp::Location>
43 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
44                    const lsp::URIForFile *uri = nullptr) {
45   Optional<lsp::Location> location;
46   loc->walk([&](Location nestedLoc) {
47     FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
48     if (!fileLoc)
49       return WalkResult::advance();
50 
51     Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
52     if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
53       location = *sourceLoc;
54       SMLoc loc = sourceMgr.FindLocForLineAndColumn(
55           sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
56 
57       // Use range of potential identifier starting at location, else length 1
58       // range.
59       location->range.end.character += 1;
60       if (Optional<SMRange> range = lsp::convertTokenLocToRange(loc)) {
61         auto lineCol = sourceMgr.getLineAndColumn(range->End);
62         location->range.end.character =
63             std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
64       }
65       return WalkResult::interrupt();
66     }
67     return WalkResult::advance();
68   });
69   return location;
70 }
71 
72 /// Collect all of the locations from the given MLIR location that are not
73 /// contained within the given URI.
74 static void collectLocationsFromLoc(Location loc,
75                                     std::vector<lsp::Location> &locations,
76                                     const lsp::URIForFile &uri) {
77   SetVector<Location> visitedLocs;
78   loc->walk([&](Location nestedLoc) {
79     FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
80     if (!fileLoc || !visitedLocs.insert(nestedLoc))
81       return WalkResult::advance();
82 
83     Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
84     if (sourceLoc && sourceLoc->uri != uri)
85       locations.push_back(*sourceLoc);
86     return WalkResult::advance();
87   });
88 }
89 
90 /// Returns true if the given range contains the given source location. Note
91 /// that this has slightly different behavior than SMRange because it is
92 /// inclusive of the end location.
93 static bool contains(SMRange range, SMLoc loc) {
94   return range.Start.getPointer() <= loc.getPointer() &&
95          loc.getPointer() <= range.End.getPointer();
96 }
97 
98 /// Returns true if the given location is contained by the definition or one of
99 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
100 /// the range within `def` that the provided `loc` overlapped with.
101 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
102                        SMRange *overlappedRange = nullptr) {
103   // Check the main definition.
104   if (contains(def.loc, loc)) {
105     if (overlappedRange)
106       *overlappedRange = def.loc;
107     return true;
108   }
109 
110   // Check the uses.
111   const auto *useIt = llvm::find_if(
112       def.uses, [&](const SMRange &range) { return contains(range, loc); });
113   if (useIt != def.uses.end()) {
114     if (overlappedRange)
115       *overlappedRange = *useIt;
116     return true;
117   }
118   return false;
119 }
120 
121 /// Given a location pointing to a result, return the result number it refers
122 /// to or None if it refers to all of the results.
123 static Optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
124   // Skip all of the identifier characters.
125   auto isIdentifierChar = [](char c) {
126     return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
127            c == '-';
128   };
129   const char *curPtr = loc.getPointer();
130   while (isIdentifierChar(*curPtr))
131     ++curPtr;
132 
133   // Check to see if this location indexes into the result group, via `#`. If it
134   // doesn't, we can't extract a sub result number.
135   if (*curPtr != '#')
136     return llvm::None;
137 
138   // Compute the sub result number from the remaining portion of the string.
139   const char *numberStart = ++curPtr;
140   while (llvm::isDigit(*curPtr))
141     ++curPtr;
142   StringRef numberStr(numberStart, curPtr - numberStart);
143   unsigned resultNumber = 0;
144   return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>()
145                                                     : resultNumber;
146 }
147 
148 /// Given a source location range, return the text covered by the given range.
149 /// If the range is invalid, returns None.
150 static Optional<StringRef> getTextFromRange(SMRange range) {
151   if (!range.isValid())
152     return None;
153   const char *startPtr = range.Start.getPointer();
154   return StringRef(startPtr, range.End.getPointer() - startPtr);
155 }
156 
157 /// Given a block, return its position in its parent region.
158 static unsigned getBlockNumber(Block *block) {
159   return std::distance(block->getParent()->begin(), block->getIterator());
160 }
161 
162 /// Given a block and source location, print the source name of the block to the
163 /// given output stream.
164 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
165   // Try to extract a name from the source location.
166   Optional<StringRef> text = getTextFromRange(loc);
167   if (text && text->startswith("^")) {
168     os << *text;
169     return;
170   }
171 
172   // Otherwise, we don't have a name so print the block number.
173   os << "<Block #" << getBlockNumber(block) << ">";
174 }
175 static void printDefBlockName(raw_ostream &os,
176                               const AsmParserState::BlockDefinition &def) {
177   printDefBlockName(os, def.block, def.definition.loc);
178 }
179 
180 /// Convert the given MLIR diagnostic to the LSP form.
181 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
182                                                Diagnostic &diag,
183                                                const lsp::URIForFile &uri) {
184   lsp::Diagnostic lspDiag;
185   lspDiag.source = "mlir";
186 
187   // Note: Right now all of the diagnostics are treated as parser issues, but
188   // some are parser and some are verifier.
189   lspDiag.category = "Parse Error";
190 
191   // Try to grab a file location for this diagnostic.
192   // TODO: For simplicity, we just grab the first one. It may be likely that we
193   // will need a more interesting heuristic here.'
194   Optional<lsp::Location> lspLocation =
195       getLocationFromLoc(sourceMgr, diag.getLocation(), &uri);
196   if (lspLocation)
197     lspDiag.range = lspLocation->range;
198 
199   // Convert the severity for the diagnostic.
200   switch (diag.getSeverity()) {
201   case DiagnosticSeverity::Note:
202     llvm_unreachable("expected notes to be handled separately");
203   case DiagnosticSeverity::Warning:
204     lspDiag.severity = lsp::DiagnosticSeverity::Warning;
205     break;
206   case DiagnosticSeverity::Error:
207     lspDiag.severity = lsp::DiagnosticSeverity::Error;
208     break;
209   case DiagnosticSeverity::Remark:
210     lspDiag.severity = lsp::DiagnosticSeverity::Information;
211     break;
212   }
213   lspDiag.message = diag.str();
214 
215   // Attach any notes to the main diagnostic as related information.
216   std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
217   for (Diagnostic &note : diag.getNotes()) {
218     lsp::Location noteLoc;
219     if (Optional<lsp::Location> loc =
220             getLocationFromLoc(sourceMgr, note.getLocation()))
221       noteLoc = *loc;
222     else
223       noteLoc.uri = uri;
224     relatedDiags.emplace_back(noteLoc, note.str());
225   }
226   if (!relatedDiags.empty())
227     lspDiag.relatedInformation = std::move(relatedDiags);
228 
229   return lspDiag;
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // MLIRDocument
234 //===----------------------------------------------------------------------===//
235 
236 namespace {
237 /// This class represents all of the information pertaining to a specific MLIR
238 /// document.
239 struct MLIRDocument {
240   MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
241                StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
242   MLIRDocument(const MLIRDocument &) = delete;
243   MLIRDocument &operator=(const MLIRDocument &) = delete;
244 
245   //===--------------------------------------------------------------------===//
246   // Definitions and References
247   //===--------------------------------------------------------------------===//
248 
249   void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
250                       std::vector<lsp::Location> &locations);
251   void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
252                         std::vector<lsp::Location> &references);
253 
254   //===--------------------------------------------------------------------===//
255   // Hover
256   //===--------------------------------------------------------------------===//
257 
258   Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
259                                  const lsp::Position &hoverPos);
260   Optional<lsp::Hover>
261   buildHoverForOperation(SMRange hoverRange,
262                          const AsmParserState::OperationDefinition &op);
263   lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
264                                           unsigned resultStart,
265                                           unsigned resultEnd, SMLoc posLoc);
266   lsp::Hover buildHoverForBlock(SMRange hoverRange,
267                                 const AsmParserState::BlockDefinition &block);
268   lsp::Hover
269   buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
270                              const AsmParserState::BlockDefinition &block);
271 
272   //===--------------------------------------------------------------------===//
273   // Document Symbols
274   //===--------------------------------------------------------------------===//
275 
276   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
277   void findDocumentSymbols(Operation *op,
278                            std::vector<lsp::DocumentSymbol> &symbols);
279 
280   //===--------------------------------------------------------------------===//
281   // Code Completion
282   //===--------------------------------------------------------------------===//
283 
284   lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
285                                         const lsp::Position &completePos,
286                                         const DialectRegistry &registry);
287 
288   //===--------------------------------------------------------------------===//
289   // Fields
290   //===--------------------------------------------------------------------===//
291 
292   /// The high level parser state used to find definitions and references within
293   /// the source file.
294   AsmParserState asmState;
295 
296   /// The container for the IR parsed from the input file.
297   Block parsedIR;
298 
299   /// The source manager containing the contents of the input file.
300   llvm::SourceMgr sourceMgr;
301 };
302 } // namespace
303 
304 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
305                            StringRef contents,
306                            std::vector<lsp::Diagnostic> &diagnostics) {
307   ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
308     diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
309   });
310 
311   // Try to parsed the given IR string.
312   auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
313   if (!memBuffer) {
314     lsp::Logger::error("Failed to create memory buffer for file", uri.file());
315     return;
316   }
317 
318   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
319   if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr,
320                              &asmState))) {
321     // If parsing failed, clear out any of the current state.
322     parsedIR.clear();
323     asmState = AsmParserState();
324     return;
325   }
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // MLIRDocument: Definitions and References
330 //===----------------------------------------------------------------------===//
331 
332 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
333                                   const lsp::Position &defPos,
334                                   std::vector<lsp::Location> &locations) {
335   SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
336 
337   // Functor used to check if an SM definition contains the position.
338   auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
339     if (!isDefOrUse(def, posLoc))
340       return false;
341     locations.emplace_back(uri, sourceMgr, def.loc);
342     return true;
343   };
344 
345   // Check all definitions related to operations.
346   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
347     if (contains(op.loc, posLoc))
348       return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
349     for (const auto &result : op.resultGroups)
350       if (containsPosition(result.definition))
351         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
352     for (const auto &symUse : op.symbolUses) {
353       if (contains(symUse, posLoc)) {
354         locations.emplace_back(uri, sourceMgr, op.loc);
355         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
356       }
357     }
358   }
359 
360   // Check all definitions related to blocks.
361   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
362     if (containsPosition(block.definition))
363       return;
364     for (const AsmParserState::SMDefinition &arg : block.arguments)
365       if (containsPosition(arg))
366         return;
367   }
368 }
369 
370 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
371                                     const lsp::Position &pos,
372                                     std::vector<lsp::Location> &references) {
373   // Functor used to append all of the definitions/uses of the given SM
374   // definition to the reference list.
375   auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
376     references.emplace_back(uri, sourceMgr, def.loc);
377     for (const SMRange &use : def.uses)
378       references.emplace_back(uri, sourceMgr, use);
379   };
380 
381   SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
382 
383   // Check all definitions related to operations.
384   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
385     if (contains(op.loc, posLoc)) {
386       for (const auto &result : op.resultGroups)
387         appendSMDef(result.definition);
388       for (const auto &symUse : op.symbolUses)
389         if (contains(symUse, posLoc))
390           references.emplace_back(uri, sourceMgr, symUse);
391       return;
392     }
393     for (const auto &result : op.resultGroups)
394       if (isDefOrUse(result.definition, posLoc))
395         return appendSMDef(result.definition);
396     for (const auto &symUse : op.symbolUses) {
397       if (!contains(symUse, posLoc))
398         continue;
399       for (const auto &symUse : op.symbolUses)
400         references.emplace_back(uri, sourceMgr, symUse);
401       return;
402     }
403   }
404 
405   // Check all definitions related to blocks.
406   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
407     if (isDefOrUse(block.definition, posLoc))
408       return appendSMDef(block.definition);
409 
410     for (const AsmParserState::SMDefinition &arg : block.arguments)
411       if (isDefOrUse(arg, posLoc))
412         return appendSMDef(arg);
413   }
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // MLIRDocument: Hover
418 //===----------------------------------------------------------------------===//
419 
420 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
421                                              const lsp::Position &hoverPos) {
422   SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
423   SMRange hoverRange;
424 
425   // Check for Hovers on operations and results.
426   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
427     // Check if the position points at this operation.
428     if (contains(op.loc, posLoc))
429       return buildHoverForOperation(op.loc, op);
430 
431     // Check if the position points at the symbol name.
432     for (auto &use : op.symbolUses)
433       if (contains(use, posLoc))
434         return buildHoverForOperation(use, op);
435 
436     // Check if the position points at a result group.
437     for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
438       const auto &result = op.resultGroups[i];
439       if (!isDefOrUse(result.definition, posLoc, &hoverRange))
440         continue;
441 
442       // Get the range of results covered by the over position.
443       unsigned resultStart = result.startIndex;
444       unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
445                                         : op.resultGroups[i + 1].startIndex;
446       return buildHoverForOperationResult(hoverRange, op.op, resultStart,
447                                           resultEnd, posLoc);
448     }
449   }
450 
451   // Check to see if the hover is over a block argument.
452   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
453     if (isDefOrUse(block.definition, posLoc, &hoverRange))
454       return buildHoverForBlock(hoverRange, block);
455 
456     for (const auto &arg : llvm::enumerate(block.arguments)) {
457       if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
458         continue;
459 
460       return buildHoverForBlockArgument(
461           hoverRange, block.block->getArgument(arg.index()), block);
462     }
463   }
464   return llvm::None;
465 }
466 
467 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
468     SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
469   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
470   llvm::raw_string_ostream os(hover.contents.value);
471 
472   // Add the operation name to the hover.
473   os << "\"" << op.op->getName() << "\"";
474   if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
475     os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
476   os << "\n\n";
477 
478   os << "Generic Form:\n\n```mlir\n";
479 
480   // Temporary drop the regions of this operation so that they don't get
481   // printed in the output. This helps keeps the size of the output hover
482   // small.
483   SmallVector<std::unique_ptr<Region>> regions;
484   for (Region &region : op.op->getRegions()) {
485     regions.emplace_back(std::make_unique<Region>());
486     regions.back()->takeBody(region);
487   }
488 
489   op.op->print(
490       os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
491   os << "\n```\n";
492 
493   // Move the regions back to the current operation.
494   for (Region &region : op.op->getRegions())
495     region.takeBody(*regions.back());
496 
497   return hover;
498 }
499 
500 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
501                                                       Operation *op,
502                                                       unsigned resultStart,
503                                                       unsigned resultEnd,
504                                                       SMLoc posLoc) {
505   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
506   llvm::raw_string_ostream os(hover.contents.value);
507 
508   // Add the parent operation name to the hover.
509   os << "Operation: \"" << op->getName() << "\"\n\n";
510 
511   // Check to see if the location points to a specific result within the
512   // group.
513   if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
514     if ((resultStart + *resultNumber) < resultEnd) {
515       resultStart += *resultNumber;
516       resultEnd = resultStart + 1;
517     }
518   }
519 
520   // Add the range of results and their types to the hover info.
521   if ((resultStart + 1) == resultEnd) {
522     os << "Result #" << resultStart << "\n\n"
523        << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
524   } else {
525     os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
526        << "Types: ";
527     llvm::interleaveComma(
528         op->getResults().slice(resultStart, resultEnd), os,
529         [&](Value result) { os << "`" << result.getType() << "`"; });
530   }
531 
532   return hover;
533 }
534 
535 lsp::Hover
536 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
537                                  const AsmParserState::BlockDefinition &block) {
538   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
539   llvm::raw_string_ostream os(hover.contents.value);
540 
541   // Print the given block to the hover output stream.
542   auto printBlockToHover = [&](Block *newBlock) {
543     if (const auto *def = asmState.getBlockDef(newBlock))
544       printDefBlockName(os, *def);
545     else
546       printDefBlockName(os, newBlock);
547   };
548 
549   // Display the parent operation, block number, predecessors, and successors.
550   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
551      << "Block #" << getBlockNumber(block.block) << "\n\n";
552   if (!block.block->hasNoPredecessors()) {
553     os << "Predecessors: ";
554     llvm::interleaveComma(block.block->getPredecessors(), os,
555                           printBlockToHover);
556     os << "\n\n";
557   }
558   if (!block.block->hasNoSuccessors()) {
559     os << "Successors: ";
560     llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
561     os << "\n\n";
562   }
563 
564   return hover;
565 }
566 
567 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
568     SMRange hoverRange, BlockArgument arg,
569     const AsmParserState::BlockDefinition &block) {
570   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
571   llvm::raw_string_ostream os(hover.contents.value);
572 
573   // Display the parent operation, block, the argument number, and the type.
574   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
575      << "Block: ";
576   printDefBlockName(os, block);
577   os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
578      << "Type: `" << arg.getType() << "`\n\n";
579 
580   return hover;
581 }
582 
583 //===----------------------------------------------------------------------===//
584 // MLIRDocument: Document Symbols
585 //===----------------------------------------------------------------------===//
586 
587 void MLIRDocument::findDocumentSymbols(
588     std::vector<lsp::DocumentSymbol> &symbols) {
589   for (Operation &op : parsedIR)
590     findDocumentSymbols(&op, symbols);
591 }
592 
593 void MLIRDocument::findDocumentSymbols(
594     Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
595   std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
596 
597   // Check for the source information of this operation.
598   if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
599     // If this operation defines a symbol, record it.
600     if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
601       symbols.emplace_back(symbol.getName(),
602                            isa<FunctionOpInterface>(op)
603                                ? lsp::SymbolKind::Function
604                                : lsp::SymbolKind::Class,
605                            lsp::Range(sourceMgr, def->scopeLoc),
606                            lsp::Range(sourceMgr, def->loc));
607       childSymbols = &symbols.back().children;
608 
609     } else if (op->hasTrait<OpTrait::SymbolTable>()) {
610       // Otherwise, if this is a symbol table push an anonymous document symbol.
611       symbols.emplace_back("<" + op->getName().getStringRef() + ">",
612                            lsp::SymbolKind::Namespace,
613                            lsp::Range(sourceMgr, def->scopeLoc),
614                            lsp::Range(sourceMgr, def->loc));
615       childSymbols = &symbols.back().children;
616     }
617   }
618 
619   // Recurse into the regions of this operation.
620   if (!op->getNumRegions())
621     return;
622   for (Region &region : op->getRegions())
623     for (Operation &childOp : region.getOps())
624       findDocumentSymbols(&childOp, *childSymbols);
625 }
626 
627 //===----------------------------------------------------------------------===//
628 // MLIRDocument: Code Completion
629 //===----------------------------------------------------------------------===//
630 
631 namespace {
632 class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
633 public:
634   LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
635                          MLIRContext *ctx)
636       : AsmParserCodeCompleteContext(completeLoc),
637         completionList(completionList), ctx(ctx) {}
638 
639   /// Signal code completion for a dialect name.
640   void completeDialectName() final {
641     for (StringRef dialect : ctx->getAvailableDialects()) {
642       lsp::CompletionItem item(dialect, lsp::CompletionItemKind::Module);
643       item.sortText = "2";
644       item.detail = "dialect";
645       completionList.items.emplace_back(item);
646     }
647   }
648 
649   /// Signal code completion for an operation name within the given dialect.
650   void completeOperationName(StringRef dialectName) final {
651     Dialect *dialect = ctx->getOrLoadDialect(dialectName);
652     if (!dialect)
653       return;
654 
655     for (const auto &op : ctx->getRegisteredOperations()) {
656       if (&op.getDialect() != dialect)
657         continue;
658 
659       lsp::CompletionItem item(
660           op.getStringRef().drop_front(dialectName.size() + 1),
661           lsp::CompletionItemKind::Field);
662       item.sortText = "1";
663       item.detail = "operation";
664       completionList.items.emplace_back(item);
665     }
666   }
667 
668   /// Append the given SSA value as a code completion result for SSA value
669   /// completions.
670   void appendSSAValueCompletion(StringRef name, std::string typeData) final {
671     // Check if we need to insert the `%` or not.
672     bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
673 
674     lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable);
675     if (stripPrefix)
676       item.insertText = name.drop_front(1).str();
677     item.detail = std::move(typeData);
678     completionList.items.emplace_back(item);
679   }
680 
681   /// Append the given block as a code completion result for block name
682   /// completions.
683   void appendBlockCompletion(StringRef name) final {
684     // Check if we need to insert the `^` or not.
685     bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
686 
687     lsp::CompletionItem item(name, lsp::CompletionItemKind::Field);
688     if (stripPrefix)
689       item.insertText = name.drop_front(1).str();
690     completionList.items.emplace_back(item);
691   }
692 
693 private:
694   lsp::CompletionList &completionList;
695   MLIRContext *ctx;
696 };
697 } // namespace
698 
699 lsp::CompletionList
700 MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
701                                 const lsp::Position &completePos,
702                                 const DialectRegistry &registry) {
703   SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
704   if (!posLoc.isValid())
705     return lsp::CompletionList();
706 
707   // To perform code completion, we run another parse of the module with the
708   // code completion context provided.
709   MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
710   tmpContext.allowUnregisteredDialects();
711   lsp::CompletionList completionList;
712   LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
713                                             &tmpContext);
714 
715   Block tmpIR;
716   AsmParserState tmpState;
717   (void)parseSourceFile(sourceMgr, &tmpIR, &tmpContext,
718                         /*sourceFileLoc=*/nullptr, &tmpState,
719                         &lspCompleteContext);
720   return completionList;
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // MLIRTextFileChunk
725 //===----------------------------------------------------------------------===//
726 
727 namespace {
728 /// This class represents a single chunk of an MLIR text file.
729 struct MLIRTextFileChunk {
730   MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
731                     const lsp::URIForFile &uri, StringRef contents,
732                     std::vector<lsp::Diagnostic> &diagnostics)
733       : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
734 
735   /// Adjust the line number of the given range to anchor at the beginning of
736   /// the file, instead of the beginning of this chunk.
737   void adjustLocForChunkOffset(lsp::Range &range) {
738     adjustLocForChunkOffset(range.start);
739     adjustLocForChunkOffset(range.end);
740   }
741   /// Adjust the line number of the given position to anchor at the beginning of
742   /// the file, instead of the beginning of this chunk.
743   void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
744 
745   /// The line offset of this chunk from the beginning of the file.
746   uint64_t lineOffset;
747   /// The document referred to by this chunk.
748   MLIRDocument document;
749 };
750 } // namespace
751 
752 //===----------------------------------------------------------------------===//
753 // MLIRTextFile
754 //===----------------------------------------------------------------------===//
755 
756 namespace {
/// This class represents a text file containing one or more MLIR documents.
/// The file contents are split on the `// -----` marker, and each piece is
/// held as its own chunk. Queries take positions relative to the whole file,
/// are dispatched to the chunk containing the position, and have their results
/// shifted back to be relative to the whole file.
class MLIRTextFile {
public:
  /// Parse `fileContents` into chunks using the dialects in `registry`.
  /// Diagnostics produced while building the chunks are appended to
  /// `diagnostics`, with their lines adjusted to be relative to the file.
  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
               int64_t version, DialectRegistry &registry,
               std::vector<lsp::Diagnostic> &diagnostics);

  /// Return the current version of this text file.
  int64_t getVersion() const { return version; }

  //===--------------------------------------------------------------------===//
  // LSP Queries
  //===--------------------------------------------------------------------===//

  // Each positional query below takes its position by value: the position is
  // rewritten in place to be chunk-relative before being forwarded to the
  // chunk's document.

  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
                      std::vector<lsp::Location> &locations);
  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
                        std::vector<lsp::Location> &references);
  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 lsp::Position hoverPos);
  /// When the file holds more than one chunk, the symbols of each chunk are
  /// nested under a synthetic `<file-split-N>` symbol.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
                                        lsp::Position completePos);

private:
  /// Find the MLIR document that contains the given position, and update the
  /// position to be anchored at the start of the found chunk instead of the
  /// beginning of the file.
  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);

  /// The context used to hold the state contained by the parsed document.
  /// Declared before `chunks`, which are constructed against this context.
  MLIRContext context;

  /// The full string contents of the file.
  std::string contents;

  /// The version of this file.
  int64_t version;

  /// The number of lines in the file. Computed by the constructor as the
  /// total number of '\n' characters found across all chunks.
  int64_t totalNumLines = 0;

  /// The chunks of this file. The order of these chunks is the order in which
  /// they appear in the text file.
  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
};
803 } // namespace
804 
805 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
806                            int64_t version, DialectRegistry &registry,
807                            std::vector<lsp::Diagnostic> &diagnostics)
808     : context(registry, MLIRContext::Threading::DISABLED),
809       contents(fileContents.str()), version(version) {
810   context.allowUnregisteredDialects();
811 
812   // Split the file into separate MLIR documents.
813   // TODO: Find a way to share the split file marker with other tools. We don't
814   // want to use `splitAndProcessBuffer` here, but we do want to make sure this
815   // marker doesn't go out of sync.
816   SmallVector<StringRef, 8> subContents;
817   StringRef(contents).split(subContents, "// -----");
818   chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
819       context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
820 
821   uint64_t lineOffset = subContents.front().count('\n');
822   for (StringRef docContents : llvm::drop_begin(subContents)) {
823     unsigned currentNumDiags = diagnostics.size();
824     auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
825                                                      docContents, diagnostics);
826     lineOffset += docContents.count('\n');
827 
828     // Adjust locations used in diagnostics to account for the offset from the
829     // beginning of the file.
830     for (lsp::Diagnostic &diag :
831          llvm::drop_begin(diagnostics, currentNumDiags)) {
832       chunk->adjustLocForChunkOffset(diag.range);
833 
834       if (!diag.relatedInformation)
835         continue;
836       for (auto &it : *diag.relatedInformation)
837         if (it.location.uri == uri)
838           chunk->adjustLocForChunkOffset(it.location.range);
839     }
840     chunks.emplace_back(std::move(chunk));
841   }
842   totalNumLines = lineOffset;
843 }
844 
845 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
846                                   lsp::Position defPos,
847                                   std::vector<lsp::Location> &locations) {
848   MLIRTextFileChunk &chunk = getChunkFor(defPos);
849   chunk.document.getLocationsOf(uri, defPos, locations);
850 
851   // Adjust any locations within this file for the offset of this chunk.
852   if (chunk.lineOffset == 0)
853     return;
854   for (lsp::Location &loc : locations)
855     if (loc.uri == uri)
856       chunk.adjustLocForChunkOffset(loc.range);
857 }
858 
859 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
860                                     lsp::Position pos,
861                                     std::vector<lsp::Location> &references) {
862   MLIRTextFileChunk &chunk = getChunkFor(pos);
863   chunk.document.findReferencesOf(uri, pos, references);
864 
865   // Adjust any locations within this file for the offset of this chunk.
866   if (chunk.lineOffset == 0)
867     return;
868   for (lsp::Location &loc : references)
869     if (loc.uri == uri)
870       chunk.adjustLocForChunkOffset(loc.range);
871 }
872 
873 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
874                                              lsp::Position hoverPos) {
875   MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
876   Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
877 
878   // Adjust any locations within this file for the offset of this chunk.
879   if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
880     chunk.adjustLocForChunkOffset(*hoverInfo->range);
881   return hoverInfo;
882 }
883 
884 void MLIRTextFile::findDocumentSymbols(
885     std::vector<lsp::DocumentSymbol> &symbols) {
886   if (chunks.size() == 1)
887     return chunks.front()->document.findDocumentSymbols(symbols);
888 
889   // If there are multiple chunks in this file, we create top-level symbols for
890   // each chunk.
891   for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
892     MLIRTextFileChunk &chunk = *chunks[i];
893     lsp::Position startPos(chunk.lineOffset);
894     lsp::Position endPos((i == e - 1) ? totalNumLines - 1
895                                       : chunks[i + 1]->lineOffset);
896     lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
897                                lsp::SymbolKind::Namespace,
898                                /*range=*/lsp::Range(startPos, endPos),
899                                /*selectionRange=*/lsp::Range(startPos));
900     chunk.document.findDocumentSymbols(symbol.children);
901 
902     // Fixup the locations of document symbols within this chunk.
903     if (i != 0) {
904       SmallVector<lsp::DocumentSymbol *> symbolsToFix;
905       for (lsp::DocumentSymbol &childSymbol : symbol.children)
906         symbolsToFix.push_back(&childSymbol);
907 
908       while (!symbolsToFix.empty()) {
909         lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
910         chunk.adjustLocForChunkOffset(symbol->range);
911         chunk.adjustLocForChunkOffset(symbol->selectionRange);
912 
913         for (lsp::DocumentSymbol &childSymbol : symbol->children)
914           symbolsToFix.push_back(&childSymbol);
915       }
916     }
917 
918     // Push the symbol for this chunk.
919     symbols.emplace_back(std::move(symbol));
920   }
921 }
922 
923 lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
924                                                     lsp::Position completePos) {
925   MLIRTextFileChunk &chunk = getChunkFor(completePos);
926   lsp::CompletionList completionList = chunk.document.getCodeCompletion(
927       uri, completePos, context.getDialectRegistry());
928 
929   // Adjust any completion locations.
930   for (lsp::CompletionItem &item : completionList.items) {
931     if (item.textEdit)
932       chunk.adjustLocForChunkOffset(item.textEdit->range);
933     for (lsp::TextEdit &edit : item.additionalTextEdits)
934       chunk.adjustLocForChunkOffset(edit.range);
935   }
936   return completionList;
937 }
938 
939 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
940   if (chunks.size() == 1)
941     return *chunks.front();
942 
943   // Search for the first chunk with a greater line offset, the previous chunk
944   // is the one that contains `pos`.
945   auto it = llvm::upper_bound(
946       chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
947         return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
948       });
949   MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
950   pos.line -= chunk.lineOffset;
951   return chunk;
952 }
953 
954 //===----------------------------------------------------------------------===//
955 // MLIRServer::Impl
956 //===----------------------------------------------------------------------===//
957 
/// The private state of `MLIRServer`: the dialect registry used when parsing
/// and the set of currently tracked files.
struct lsp::MLIRServer::Impl {
  /// Binds to `registry` by reference; no copy of the registry is made.
  Impl(DialectRegistry &registry) : registry(registry) {}

  /// The registry containing dialects that can be recognized in parsed .mlir
  /// files. Held by reference.
  DialectRegistry &registry;

  /// The files held by the server, mapped by their URI file name (the value
  /// of `URIForFile::file()`).
  llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
};
968 
969 //===----------------------------------------------------------------------===//
970 // MLIRServer
971 //===----------------------------------------------------------------------===//
972 
/// Create the server's private implementation, which keeps a reference to
/// `registry`.
lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
    : impl(std::make_unique<Impl>(registry)) {}
/// Defaulted in this file, after the definition of `Impl` above.
lsp::MLIRServer::~MLIRServer() = default;
976 
977 void lsp::MLIRServer::addOrUpdateDocument(
978     const URIForFile &uri, StringRef contents, int64_t version,
979     std::vector<Diagnostic> &diagnostics) {
980   impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
981       uri, contents, version, impl->registry, diagnostics);
982 }
983 
984 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
985   auto it = impl->files.find(uri.file());
986   if (it == impl->files.end())
987     return llvm::None;
988 
989   int64_t version = it->second->getVersion();
990   impl->files.erase(it);
991   return version;
992 }
993 
994 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
995                                      const Position &defPos,
996                                      std::vector<Location> &locations) {
997   auto fileIt = impl->files.find(uri.file());
998   if (fileIt != impl->files.end())
999     fileIt->second->getLocationsOf(uri, defPos, locations);
1000 }
1001 
1002 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
1003                                        const Position &pos,
1004                                        std::vector<Location> &references) {
1005   auto fileIt = impl->files.find(uri.file());
1006   if (fileIt != impl->files.end())
1007     fileIt->second->findReferencesOf(uri, pos, references);
1008 }
1009 
1010 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1011                                                 const Position &hoverPos) {
1012   auto fileIt = impl->files.find(uri.file());
1013   if (fileIt != impl->files.end())
1014     return fileIt->second->findHover(uri, hoverPos);
1015   return llvm::None;
1016 }
1017 
1018 void lsp::MLIRServer::findDocumentSymbols(
1019     const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1020   auto fileIt = impl->files.find(uri.file());
1021   if (fileIt != impl->files.end())
1022     fileIt->second->findDocumentSymbols(symbols);
1023 }
1024 
1025 lsp::CompletionList
1026 lsp::MLIRServer::getCodeCompletion(const URIForFile &uri,
1027                                    const Position &completePos) {
1028   auto fileIt = impl->files.find(uri.file());
1029   if (fileIt != impl->files.end())
1030     return fileIt->second->getCodeCompletion(uri, completePos);
1031   return CompletionList();
1032 }
1033