1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "MLIRServer.h"
10 #include "../lsp-server-support/Logging.h"
11 #include "../lsp-server-support/Protocol.h"
12 #include "../lsp-server-support/SourceMgrUtils.h"
13 #include "mlir/AsmParser/AsmParser.h"
14 #include "mlir/AsmParser/AsmParserState.h"
15 #include "mlir/AsmParser/CodeComplete.h"
16 #include "mlir/IR/FunctionInterfaces.h"
17 #include "mlir/IR/Operation.h"
18 #include "llvm/Support/SourceMgr.h"
19
20 using namespace mlir;
21
22 /// Returns a language server location from the given MLIR file location.
getLocationFromLoc(FileLineColLoc loc)23 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) {
24 llvm::Expected<lsp::URIForFile> sourceURI =
25 lsp::URIForFile::fromFile(loc.getFilename());
26 if (!sourceURI) {
27 lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
28 loc.getFilename(),
29 llvm::toString(sourceURI.takeError()));
30 return llvm::None;
31 }
32
33 lsp::Position position;
34 position.line = loc.getLine() - 1;
35 position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
36 return lsp::Location{*sourceURI, lsp::Range(position)};
37 }
38
39 /// Returns a language server location from the given MLIR location, or None if
40 /// one couldn't be created. `uri` is an optional additional filter that, when
41 /// present, is used to filter sub locations that do not share the same uri.
42 static Optional<lsp::Location>
getLocationFromLoc(llvm::SourceMgr & sourceMgr,Location loc,const lsp::URIForFile * uri=nullptr)43 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
44 const lsp::URIForFile *uri = nullptr) {
45 Optional<lsp::Location> location;
46 loc->walk([&](Location nestedLoc) {
47 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
48 if (!fileLoc)
49 return WalkResult::advance();
50
51 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
52 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
53 location = *sourceLoc;
54 SMLoc loc = sourceMgr.FindLocForLineAndColumn(
55 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
56
57 // Use range of potential identifier starting at location, else length 1
58 // range.
59 location->range.end.character += 1;
60 if (Optional<SMRange> range = lsp::convertTokenLocToRange(loc)) {
61 auto lineCol = sourceMgr.getLineAndColumn(range->End);
62 location->range.end.character =
63 std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
64 }
65 return WalkResult::interrupt();
66 }
67 return WalkResult::advance();
68 });
69 return location;
70 }
71
72 /// Collect all of the locations from the given MLIR location that are not
73 /// contained within the given URI.
collectLocationsFromLoc(Location loc,std::vector<lsp::Location> & locations,const lsp::URIForFile & uri)74 static void collectLocationsFromLoc(Location loc,
75 std::vector<lsp::Location> &locations,
76 const lsp::URIForFile &uri) {
77 SetVector<Location> visitedLocs;
78 loc->walk([&](Location nestedLoc) {
79 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
80 if (!fileLoc || !visitedLocs.insert(nestedLoc))
81 return WalkResult::advance();
82
83 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
84 if (sourceLoc && sourceLoc->uri != uri)
85 locations.push_back(*sourceLoc);
86 return WalkResult::advance();
87 });
88 }
89
90 /// Returns true if the given range contains the given source location. Note
91 /// that this has slightly different behavior than SMRange because it is
92 /// inclusive of the end location.
contains(SMRange range,SMLoc loc)93 static bool contains(SMRange range, SMLoc loc) {
94 return range.Start.getPointer() <= loc.getPointer() &&
95 loc.getPointer() <= range.End.getPointer();
96 }
97
98 /// Returns true if the given location is contained by the definition or one of
99 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
100 /// the range within `def` that the provided `loc` overlapped with.
isDefOrUse(const AsmParserState::SMDefinition & def,SMLoc loc,SMRange * overlappedRange=nullptr)101 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
102 SMRange *overlappedRange = nullptr) {
103 // Check the main definition.
104 if (contains(def.loc, loc)) {
105 if (overlappedRange)
106 *overlappedRange = def.loc;
107 return true;
108 }
109
110 // Check the uses.
111 const auto *useIt = llvm::find_if(
112 def.uses, [&](const SMRange &range) { return contains(range, loc); });
113 if (useIt != def.uses.end()) {
114 if (overlappedRange)
115 *overlappedRange = *useIt;
116 return true;
117 }
118 return false;
119 }
120
121 /// Given a location pointing to a result, return the result number it refers
122 /// to or None if it refers to all of the results.
getResultNumberFromLoc(SMLoc loc)123 static Optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
124 // Skip all of the identifier characters.
125 auto isIdentifierChar = [](char c) {
126 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
127 c == '-';
128 };
129 const char *curPtr = loc.getPointer();
130 while (isIdentifierChar(*curPtr))
131 ++curPtr;
132
133 // Check to see if this location indexes into the result group, via `#`. If it
134 // doesn't, we can't extract a sub result number.
135 if (*curPtr != '#')
136 return llvm::None;
137
138 // Compute the sub result number from the remaining portion of the string.
139 const char *numberStart = ++curPtr;
140 while (llvm::isDigit(*curPtr))
141 ++curPtr;
142 StringRef numberStr(numberStart, curPtr - numberStart);
143 unsigned resultNumber = 0;
144 return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>()
145 : resultNumber;
146 }
147
148 /// Given a source location range, return the text covered by the given range.
149 /// If the range is invalid, returns None.
getTextFromRange(SMRange range)150 static Optional<StringRef> getTextFromRange(SMRange range) {
151 if (!range.isValid())
152 return None;
153 const char *startPtr = range.Start.getPointer();
154 return StringRef(startPtr, range.End.getPointer() - startPtr);
155 }
156
157 /// Given a block, return its position in its parent region.
getBlockNumber(Block * block)158 static unsigned getBlockNumber(Block *block) {
159 return std::distance(block->getParent()->begin(), block->getIterator());
160 }
161
162 /// Given a block and source location, print the source name of the block to the
163 /// given output stream.
printDefBlockName(raw_ostream & os,Block * block,SMRange loc={})164 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
165 // Try to extract a name from the source location.
166 Optional<StringRef> text = getTextFromRange(loc);
167 if (text && text->startswith("^")) {
168 os << *text;
169 return;
170 }
171
172 // Otherwise, we don't have a name so print the block number.
173 os << "<Block #" << getBlockNumber(block) << ">";
174 }
printDefBlockName(raw_ostream & os,const AsmParserState::BlockDefinition & def)175 static void printDefBlockName(raw_ostream &os,
176 const AsmParserState::BlockDefinition &def) {
177 printDefBlockName(os, def.block, def.definition.loc);
178 }
179
180 /// Convert the given MLIR diagnostic to the LSP form.
getLspDiagnoticFromDiag(llvm::SourceMgr & sourceMgr,Diagnostic & diag,const lsp::URIForFile & uri)181 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
182 Diagnostic &diag,
183 const lsp::URIForFile &uri) {
184 lsp::Diagnostic lspDiag;
185 lspDiag.source = "mlir";
186
187 // Note: Right now all of the diagnostics are treated as parser issues, but
188 // some are parser and some are verifier.
189 lspDiag.category = "Parse Error";
190
191 // Try to grab a file location for this diagnostic.
192 // TODO: For simplicity, we just grab the first one. It may be likely that we
193 // will need a more interesting heuristic here.'
194 Optional<lsp::Location> lspLocation =
195 getLocationFromLoc(sourceMgr, diag.getLocation(), &uri);
196 if (lspLocation)
197 lspDiag.range = lspLocation->range;
198
199 // Convert the severity for the diagnostic.
200 switch (diag.getSeverity()) {
201 case DiagnosticSeverity::Note:
202 llvm_unreachable("expected notes to be handled separately");
203 case DiagnosticSeverity::Warning:
204 lspDiag.severity = lsp::DiagnosticSeverity::Warning;
205 break;
206 case DiagnosticSeverity::Error:
207 lspDiag.severity = lsp::DiagnosticSeverity::Error;
208 break;
209 case DiagnosticSeverity::Remark:
210 lspDiag.severity = lsp::DiagnosticSeverity::Information;
211 break;
212 }
213 lspDiag.message = diag.str();
214
215 // Attach any notes to the main diagnostic as related information.
216 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
217 for (Diagnostic ¬e : diag.getNotes()) {
218 lsp::Location noteLoc;
219 if (Optional<lsp::Location> loc =
220 getLocationFromLoc(sourceMgr, note.getLocation()))
221 noteLoc = *loc;
222 else
223 noteLoc.uri = uri;
224 relatedDiags.emplace_back(noteLoc, note.str());
225 }
226 if (!relatedDiags.empty())
227 lspDiag.relatedInformation = std::move(relatedDiags);
228
229 return lspDiag;
230 }
231
232 //===----------------------------------------------------------------------===//
233 // MLIRDocument
234 //===----------------------------------------------------------------------===//
235
236 namespace {
237 /// This class represents all of the information pertaining to a specific MLIR
238 /// document.
239 struct MLIRDocument {
240 MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
241 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
242 MLIRDocument(const MLIRDocument &) = delete;
243 MLIRDocument &operator=(const MLIRDocument &) = delete;
244
245 //===--------------------------------------------------------------------===//
246 // Definitions and References
247 //===--------------------------------------------------------------------===//
248
249 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
250 std::vector<lsp::Location> &locations);
251 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
252 std::vector<lsp::Location> &references);
253
254 //===--------------------------------------------------------------------===//
255 // Hover
256 //===--------------------------------------------------------------------===//
257
258 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
259 const lsp::Position &hoverPos);
260 Optional<lsp::Hover>
261 buildHoverForOperation(SMRange hoverRange,
262 const AsmParserState::OperationDefinition &op);
263 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
264 unsigned resultStart,
265 unsigned resultEnd, SMLoc posLoc);
266 lsp::Hover buildHoverForBlock(SMRange hoverRange,
267 const AsmParserState::BlockDefinition &block);
268 lsp::Hover
269 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
270 const AsmParserState::BlockDefinition &block);
271
272 //===--------------------------------------------------------------------===//
273 // Document Symbols
274 //===--------------------------------------------------------------------===//
275
276 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
277 void findDocumentSymbols(Operation *op,
278 std::vector<lsp::DocumentSymbol> &symbols);
279
280 //===--------------------------------------------------------------------===//
281 // Code Completion
282 //===--------------------------------------------------------------------===//
283
284 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
285 const lsp::Position &completePos,
286 const DialectRegistry ®istry);
287
288 //===--------------------------------------------------------------------===//
289 // Code Action
290 //===--------------------------------------------------------------------===//
291
292 void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
293 lsp::Position &pos, StringRef severity,
294 StringRef message,
295 std::vector<lsp::TextEdit> &edits);
296
297 //===--------------------------------------------------------------------===//
298 // Fields
299 //===--------------------------------------------------------------------===//
300
301 /// The high level parser state used to find definitions and references within
302 /// the source file.
303 AsmParserState asmState;
304
305 /// The container for the IR parsed from the input file.
306 Block parsedIR;
307
308 /// The source manager containing the contents of the input file.
309 llvm::SourceMgr sourceMgr;
310 };
311 } // namespace
312
MLIRDocument(MLIRContext & context,const lsp::URIForFile & uri,StringRef contents,std::vector<lsp::Diagnostic> & diagnostics)313 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
314 StringRef contents,
315 std::vector<lsp::Diagnostic> &diagnostics) {
316 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
317 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
318 });
319
320 // Try to parsed the given IR string.
321 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
322 if (!memBuffer) {
323 lsp::Logger::error("Failed to create memory buffer for file", uri.file());
324 return;
325 }
326
327 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
328 if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, &context, &asmState))) {
329 // If parsing failed, clear out any of the current state.
330 parsedIR.clear();
331 asmState = AsmParserState();
332 return;
333 }
334 }
335
336 //===----------------------------------------------------------------------===//
337 // MLIRDocument: Definitions and References
338 //===----------------------------------------------------------------------===//
339
getLocationsOf(const lsp::URIForFile & uri,const lsp::Position & defPos,std::vector<lsp::Location> & locations)340 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
341 const lsp::Position &defPos,
342 std::vector<lsp::Location> &locations) {
343 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
344
345 // Functor used to check if an SM definition contains the position.
346 auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
347 if (!isDefOrUse(def, posLoc))
348 return false;
349 locations.emplace_back(uri, sourceMgr, def.loc);
350 return true;
351 };
352
353 // Check all definitions related to operations.
354 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
355 if (contains(op.loc, posLoc))
356 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
357 for (const auto &result : op.resultGroups)
358 if (containsPosition(result.definition))
359 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
360 for (const auto &symUse : op.symbolUses) {
361 if (contains(symUse, posLoc)) {
362 locations.emplace_back(uri, sourceMgr, op.loc);
363 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
364 }
365 }
366 }
367
368 // Check all definitions related to blocks.
369 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
370 if (containsPosition(block.definition))
371 return;
372 for (const AsmParserState::SMDefinition &arg : block.arguments)
373 if (containsPosition(arg))
374 return;
375 }
376 }
377
findReferencesOf(const lsp::URIForFile & uri,const lsp::Position & pos,std::vector<lsp::Location> & references)378 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
379 const lsp::Position &pos,
380 std::vector<lsp::Location> &references) {
381 // Functor used to append all of the definitions/uses of the given SM
382 // definition to the reference list.
383 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
384 references.emplace_back(uri, sourceMgr, def.loc);
385 for (const SMRange &use : def.uses)
386 references.emplace_back(uri, sourceMgr, use);
387 };
388
389 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
390
391 // Check all definitions related to operations.
392 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
393 if (contains(op.loc, posLoc)) {
394 for (const auto &result : op.resultGroups)
395 appendSMDef(result.definition);
396 for (const auto &symUse : op.symbolUses)
397 if (contains(symUse, posLoc))
398 references.emplace_back(uri, sourceMgr, symUse);
399 return;
400 }
401 for (const auto &result : op.resultGroups)
402 if (isDefOrUse(result.definition, posLoc))
403 return appendSMDef(result.definition);
404 for (const auto &symUse : op.symbolUses) {
405 if (!contains(symUse, posLoc))
406 continue;
407 for (const auto &symUse : op.symbolUses)
408 references.emplace_back(uri, sourceMgr, symUse);
409 return;
410 }
411 }
412
413 // Check all definitions related to blocks.
414 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
415 if (isDefOrUse(block.definition, posLoc))
416 return appendSMDef(block.definition);
417
418 for (const AsmParserState::SMDefinition &arg : block.arguments)
419 if (isDefOrUse(arg, posLoc))
420 return appendSMDef(arg);
421 }
422 }
423
424 //===----------------------------------------------------------------------===//
425 // MLIRDocument: Hover
426 //===----------------------------------------------------------------------===//
427
findHover(const lsp::URIForFile & uri,const lsp::Position & hoverPos)428 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
429 const lsp::Position &hoverPos) {
430 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
431 SMRange hoverRange;
432
433 // Check for Hovers on operations and results.
434 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
435 // Check if the position points at this operation.
436 if (contains(op.loc, posLoc))
437 return buildHoverForOperation(op.loc, op);
438
439 // Check if the position points at the symbol name.
440 for (auto &use : op.symbolUses)
441 if (contains(use, posLoc))
442 return buildHoverForOperation(use, op);
443
444 // Check if the position points at a result group.
445 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
446 const auto &result = op.resultGroups[i];
447 if (!isDefOrUse(result.definition, posLoc, &hoverRange))
448 continue;
449
450 // Get the range of results covered by the over position.
451 unsigned resultStart = result.startIndex;
452 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
453 : op.resultGroups[i + 1].startIndex;
454 return buildHoverForOperationResult(hoverRange, op.op, resultStart,
455 resultEnd, posLoc);
456 }
457 }
458
459 // Check to see if the hover is over a block argument.
460 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
461 if (isDefOrUse(block.definition, posLoc, &hoverRange))
462 return buildHoverForBlock(hoverRange, block);
463
464 for (const auto &arg : llvm::enumerate(block.arguments)) {
465 if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
466 continue;
467
468 return buildHoverForBlockArgument(
469 hoverRange, block.block->getArgument(arg.index()), block);
470 }
471 }
472 return llvm::None;
473 }
474
buildHoverForOperation(SMRange hoverRange,const AsmParserState::OperationDefinition & op)475 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
476 SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
477 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
478 llvm::raw_string_ostream os(hover.contents.value);
479
480 // Add the operation name to the hover.
481 os << "\"" << op.op->getName() << "\"";
482 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
483 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
484 os << "\n\n";
485
486 os << "Generic Form:\n\n```mlir\n";
487
488 // Temporary drop the regions of this operation so that they don't get
489 // printed in the output. This helps keeps the size of the output hover
490 // small.
491 SmallVector<std::unique_ptr<Region>> regions;
492 for (Region ®ion : op.op->getRegions()) {
493 regions.emplace_back(std::make_unique<Region>());
494 regions.back()->takeBody(region);
495 }
496
497 op.op->print(
498 os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
499 os << "\n```\n";
500
501 // Move the regions back to the current operation.
502 for (Region ®ion : op.op->getRegions())
503 region.takeBody(*regions.back());
504
505 return hover;
506 }
507
buildHoverForOperationResult(SMRange hoverRange,Operation * op,unsigned resultStart,unsigned resultEnd,SMLoc posLoc)508 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
509 Operation *op,
510 unsigned resultStart,
511 unsigned resultEnd,
512 SMLoc posLoc) {
513 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
514 llvm::raw_string_ostream os(hover.contents.value);
515
516 // Add the parent operation name to the hover.
517 os << "Operation: \"" << op->getName() << "\"\n\n";
518
519 // Check to see if the location points to a specific result within the
520 // group.
521 if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
522 if ((resultStart + *resultNumber) < resultEnd) {
523 resultStart += *resultNumber;
524 resultEnd = resultStart + 1;
525 }
526 }
527
528 // Add the range of results and their types to the hover info.
529 if ((resultStart + 1) == resultEnd) {
530 os << "Result #" << resultStart << "\n\n"
531 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
532 } else {
533 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
534 << "Types: ";
535 llvm::interleaveComma(
536 op->getResults().slice(resultStart, resultEnd), os,
537 [&](Value result) { os << "`" << result.getType() << "`"; });
538 }
539
540 return hover;
541 }
542
543 lsp::Hover
buildHoverForBlock(SMRange hoverRange,const AsmParserState::BlockDefinition & block)544 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
545 const AsmParserState::BlockDefinition &block) {
546 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
547 llvm::raw_string_ostream os(hover.contents.value);
548
549 // Print the given block to the hover output stream.
550 auto printBlockToHover = [&](Block *newBlock) {
551 if (const auto *def = asmState.getBlockDef(newBlock))
552 printDefBlockName(os, *def);
553 else
554 printDefBlockName(os, newBlock);
555 };
556
557 // Display the parent operation, block number, predecessors, and successors.
558 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
559 << "Block #" << getBlockNumber(block.block) << "\n\n";
560 if (!block.block->hasNoPredecessors()) {
561 os << "Predecessors: ";
562 llvm::interleaveComma(block.block->getPredecessors(), os,
563 printBlockToHover);
564 os << "\n\n";
565 }
566 if (!block.block->hasNoSuccessors()) {
567 os << "Successors: ";
568 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
569 os << "\n\n";
570 }
571
572 return hover;
573 }
574
buildHoverForBlockArgument(SMRange hoverRange,BlockArgument arg,const AsmParserState::BlockDefinition & block)575 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
576 SMRange hoverRange, BlockArgument arg,
577 const AsmParserState::BlockDefinition &block) {
578 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
579 llvm::raw_string_ostream os(hover.contents.value);
580
581 // Display the parent operation, block, the argument number, and the type.
582 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
583 << "Block: ";
584 printDefBlockName(os, block);
585 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
586 << "Type: `" << arg.getType() << "`\n\n";
587
588 return hover;
589 }
590
591 //===----------------------------------------------------------------------===//
592 // MLIRDocument: Document Symbols
593 //===----------------------------------------------------------------------===//
594
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)595 void MLIRDocument::findDocumentSymbols(
596 std::vector<lsp::DocumentSymbol> &symbols) {
597 for (Operation &op : parsedIR)
598 findDocumentSymbols(&op, symbols);
599 }
600
findDocumentSymbols(Operation * op,std::vector<lsp::DocumentSymbol> & symbols)601 void MLIRDocument::findDocumentSymbols(
602 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
603 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
604
605 // Check for the source information of this operation.
606 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
607 // If this operation defines a symbol, record it.
608 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
609 symbols.emplace_back(symbol.getName(),
610 isa<FunctionOpInterface>(op)
611 ? lsp::SymbolKind::Function
612 : lsp::SymbolKind::Class,
613 lsp::Range(sourceMgr, def->scopeLoc),
614 lsp::Range(sourceMgr, def->loc));
615 childSymbols = &symbols.back().children;
616
617 } else if (op->hasTrait<OpTrait::SymbolTable>()) {
618 // Otherwise, if this is a symbol table push an anonymous document symbol.
619 symbols.emplace_back("<" + op->getName().getStringRef() + ">",
620 lsp::SymbolKind::Namespace,
621 lsp::Range(sourceMgr, def->scopeLoc),
622 lsp::Range(sourceMgr, def->loc));
623 childSymbols = &symbols.back().children;
624 }
625 }
626
627 // Recurse into the regions of this operation.
628 if (!op->getNumRegions())
629 return;
630 for (Region ®ion : op->getRegions())
631 for (Operation &childOp : region.getOps())
632 findDocumentSymbols(&childOp, *childSymbols);
633 }
634
635 //===----------------------------------------------------------------------===//
636 // MLIRDocument: Code Completion
637 //===----------------------------------------------------------------------===//
638
namespace {
/// Code completion context used by `MLIRDocument::getCodeCompletion`, where it
/// is passed to `parseAsmSourceFile`. Each hook below appends candidate items
/// to the referenced `completionList`. Presumably the parser invokes these
/// hooks when it reaches the completion location; confirm against
/// AsmParserCodeCompleteContext.
class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
public:
  /// `completeLoc` is forwarded to the base class. Results are appended to
  /// `completionList`. Dialect and operation information is queried from
  /// `ctx`.
  LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
                         MLIRContext *ctx)
      : AsmParserCodeCompleteContext(completeLoc),
        completionList(completionList), ctx(ctx) {}

  /// Signal code completion for a dialect name, with an optional prefix.
  /// Every dialect available in `ctx` is offered, with `prefix` prepended.
  void completeDialectName(StringRef prefix) final {
    for (StringRef dialect : ctx->getAvailableDialects()) {
      lsp::CompletionItem item(prefix + dialect,
                               lsp::CompletionItemKind::Module,
                               /*sortText=*/"3");
      item.detail = "dialect";
      completionList.items.emplace_back(item);
    }
  }
  // Re-expose the base class's other `completeDialectName` overload(s), which
  // the override above would otherwise hide; the no-argument calls below rely
  // on this.
  using AsmParserCodeCompleteContext::completeDialectName;

  /// Signal code completion for an operation name within the given dialect.
  /// The dialect is loaded into `ctx` if needed. Only registered operations of
  /// that dialect are offered, with the leading `<dialect>.` dropped from each
  /// label.
  void completeOperationName(StringRef dialectName) final {
    Dialect *dialect = ctx->getOrLoadDialect(dialectName);
    if (!dialect)
      return;

    for (const auto &op : ctx->getRegisteredOperations()) {
      if (&op.getDialect() != dialect)
        continue;

      lsp::CompletionItem item(
          op.getStringRef().drop_front(dialectName.size() + 1),
          lsp::CompletionItemKind::Field,
          /*sortText=*/"1");
      item.detail = "operation";
      completionList.items.emplace_back(item);
    }
  }

  /// Append the given SSA value as a code completion result for SSA value
  /// completions.
  void appendSSAValueCompletion(StringRef name, std::string typeData) final {
    // Check if we need to insert the `%` or not.
    // NOTE(review): reads the character just before the completion location;
    // this assumes that location is never at the very start of the buffer.
    bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';

    lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable);
    if (stripPrefix)
      item.insertText = name.drop_front(1).str();
    item.detail = std::move(typeData);
    completionList.items.emplace_back(item);
  }

  /// Append the given block as a code completion result for block name
  /// completions.
  void appendBlockCompletion(StringRef name) final {
    // Check if we need to insert the `^` or not.
    // NOTE(review): same buffer-start assumption as appendSSAValueCompletion.
    bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';

    lsp::CompletionItem item(name, lsp::CompletionItemKind::Field);
    if (stripPrefix)
      item.insertText = name.drop_front(1).str();
    completionList.items.emplace_back(item);
  }

  /// Signal a completion for the given expected token.
  void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
    for (StringRef token : tokens) {
      lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword,
                               /*sortText=*/"0");
      item.detail = optional ? "optional" : "";
      completionList.items.emplace_back(item);
    }
  }

  /// Signal a completion for an attribute. This offers builtin attribute
  /// keywords, dialect names prefixed with `#`, and attribute aliases prefixed
  /// with `#`.
  void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
    appendSimpleCompletions({"affine_set", "affine_map", "dense", "false",
                             "loc", "opaque", "sparse", "true", "unit"},
                            lsp::CompletionItemKind::Field,
                            /*sortText=*/"1");

    completeDialectName("#");
    completeAliases(aliases, "#");
  }
  /// Signal a completion for a dialect attribute or attribute alias. Dialect
  /// names and aliases are offered without an explicit prefix.
  void completeDialectAttributeOrAlias(
      const llvm::StringMap<Attribute> &aliases) override {
    completeDialectName();
    completeAliases(aliases);
  }

  /// Signal a completion for a type.
  void completeType(const llvm::StringMap<Type> &aliases) override {
    // Handle the various builtin types.
    appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
                             "bf16", "f16", "f32", "f64", "f80", "f128",
                             "index", "none"},
                            lsp::CompletionItemKind::Field,
                            /*sortText=*/"1");

    // Handle the builtin integer types. The label shows `<N>` as a width
    // placeholder, but only the bare prefix is inserted.
    for (StringRef type : {"i", "si", "ui"}) {
      lsp::CompletionItem item(type + "<N>", lsp::CompletionItemKind::Field,
                               /*sortText=*/"1");
      item.insertText = type.str();
      completionList.items.emplace_back(item);
    }

    // Insert completions for dialect types and aliases.
    completeDialectName("!");
    completeAliases(aliases, "!");
  }
  /// Signal a completion for a dialect type or type alias. Dialect names and
  /// aliases are offered without an explicit prefix.
  void
  completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
    completeDialectName();
    completeAliases(aliases);
  }

  /// Add completion results for the given set of aliases. Each alias key is
  /// prefixed with `prefix`, and its aliased value is printed into the item
  /// detail.
  template <typename T>
  void completeAliases(const llvm::StringMap<T> &aliases,
                       StringRef prefix = "") {
    for (const auto &alias : aliases) {
      lsp::CompletionItem item(prefix + alias.getKey(),
                               lsp::CompletionItemKind::Field,
                               /*sortText=*/"2");
      llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
      completionList.items.emplace_back(item);
    }
  }

  /// Add a set of simple completions that all have the same kind.
  void appendSimpleCompletions(ArrayRef<StringRef> completions,
                               lsp::CompletionItemKind kind,
                               StringRef sortText = "") {
    for (StringRef completion : completions)
      completionList.items.emplace_back(completion, kind, sortText);
  }

private:
  /// The list that completion results are appended to (not owned).
  lsp::CompletionList &completionList;
  /// The context used to look up dialects and registered operations.
  MLIRContext *ctx;
};
} // namespace
782
783 lsp::CompletionList
getCodeCompletion(const lsp::URIForFile & uri,const lsp::Position & completePos,const DialectRegistry & registry)784 MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
785 const lsp::Position &completePos,
786 const DialectRegistry ®istry) {
787 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
788 if (!posLoc.isValid())
789 return lsp::CompletionList();
790
791 // To perform code completion, we run another parse of the module with the
792 // code completion context provided.
793 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
794 tmpContext.allowUnregisteredDialects();
795 lsp::CompletionList completionList;
796 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
797 &tmpContext);
798
799 Block tmpIR;
800 AsmParserState tmpState;
801 (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
802 &lspCompleteContext);
803 return completionList;
804 }
805
806 //===----------------------------------------------------------------------===//
807 // MLIRDocument: Code Action
808 //===----------------------------------------------------------------------===//
809
getCodeActionForDiagnostic(const lsp::URIForFile & uri,lsp::Position & pos,StringRef severity,StringRef message,std::vector<lsp::TextEdit> & edits)810 void MLIRDocument::getCodeActionForDiagnostic(
811 const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
812 StringRef message, std::vector<lsp::TextEdit> &edits) {
813 // Ignore diagnostics that print the current operation. These are always
814 // enabled for the language server, but not generally during normal
815 // parsing/verification.
816 if (message.startswith("see current operation: "))
817 return;
818
819 // Get the start of the line containing the diagnostic.
820 const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
821 const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
822 if (!lineStart)
823 return;
824 StringRef line(lineStart, pos.character);
825
826 // Add a text edit for adding an expected-* diagnostic check for this
827 // diagnostic.
828 lsp::TextEdit edit;
829 edit.range = lsp::Range(lsp::Position(pos.line, 0));
830
831 // Use the indent of the current line for the expected-* diagnostic.
832 size_t indent = line.find_first_not_of(" ");
833 if (indent == StringRef::npos)
834 indent = line.size();
835
836 edit.newText.append(indent, ' ');
837 llvm::raw_string_ostream(edit.newText)
838 << "// expected-" << severity << " @below {{" << message << "}}\n";
839 edits.emplace_back(std::move(edit));
840 }
841
842 //===----------------------------------------------------------------------===//
843 // MLIRTextFileChunk
844 //===----------------------------------------------------------------------===//
845
846 namespace {
847 /// This class represents a single chunk of an MLIR text file.
848 struct MLIRTextFileChunk {
MLIRTextFileChunk__anone0f8adc70c11::MLIRTextFileChunk849 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
850 const lsp::URIForFile &uri, StringRef contents,
851 std::vector<lsp::Diagnostic> &diagnostics)
852 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
853
854 /// Adjust the line number of the given range to anchor at the beginning of
855 /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anone0f8adc70c11::MLIRTextFileChunk856 void adjustLocForChunkOffset(lsp::Range &range) {
857 adjustLocForChunkOffset(range.start);
858 adjustLocForChunkOffset(range.end);
859 }
860 /// Adjust the line number of the given position to anchor at the beginning of
861 /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anone0f8adc70c11::MLIRTextFileChunk862 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
863
864 /// The line offset of this chunk from the beginning of the file.
865 uint64_t lineOffset;
866 /// The document referred to by this chunk.
867 MLIRDocument document;
868 };
869 } // namespace
870
871 //===----------------------------------------------------------------------===//
872 // MLIRTextFile
873 //===----------------------------------------------------------------------===//
874
namespace {
/// This class represents a text file containing one or more MLIR documents.
///
/// The constructor splits the file on `// -----` markers and builds a
/// separate chunk for each piece. The LSP queries below accept file-relative
/// positions, dispatch them to the chunk containing the position, and rebase
/// results located in this file back to file-relative lines.
class MLIRTextFile {
public:
  /// Split `fileContents` into chunks and build a document for each one.
  /// Diagnostics are appended to `diagnostics`, with their locations in this
  /// file rebased onto the whole file.
  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
               int64_t version, DialectRegistry &registry,
               std::vector<lsp::Diagnostic> &diagnostics);

  /// Return the current version of this text file.
  int64_t getVersion() const { return version; }

  //===--------------------------------------------------------------------===//
  // LSP Queries
  //===--------------------------------------------------------------------===//

  /// Append the locations reported for the entity at `defPos` (presumably its
  /// definition) to `locations`.
  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
                      std::vector<lsp::Location> &locations);
  /// Append the references reported for the entity at `pos` to `references`.
  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
                        std::vector<lsp::Location> &references);
  /// Return hover information for `hoverPos`, if the chunk's document has any.
  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 lsp::Position hoverPos);
  /// Append the file's document symbols to `symbols`. With several chunks,
  /// each chunk contributes one top-level namespace symbol.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
  /// Return the code completion results at `completePos`.
  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
                                        lsp::Position completePos);
  /// Append quick-fix actions that insert expected-* checks for the "mlir"
  /// diagnostics in `context` to `actions`.
  void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
                      const lsp::CodeActionContext &context,
                      std::vector<lsp::CodeAction> &actions);

private:
  /// Find the MLIR document that contains the given position, and update the
  /// position to be anchored at the start of the found chunk instead of the
  /// beginning of the file.
  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);

  /// The context used to hold the state contained by the parsed document.
  /// NOTE: declared before `chunks`, so it is constructed before them and
  /// destroyed after them; every chunk is built with a reference to it.
  MLIRContext context;

  /// The full string contents of the file.
  std::string contents;

  /// The version of this file.
  int64_t version;

  /// The number of newline characters in the file, accumulated over all
  /// chunks by the constructor. This equals the number of lines only when the
  /// file ends with a newline; otherwise there is one more, unterminated line.
  int64_t totalNumLines = 0;

  /// The chunks of this file. The order of these chunks is the order in which
  /// they appear in the text file.
  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
};
} // namespace
926
MLIRTextFile(const lsp::URIForFile & uri,StringRef fileContents,int64_t version,DialectRegistry & registry,std::vector<lsp::Diagnostic> & diagnostics)927 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
928 int64_t version, DialectRegistry ®istry,
929 std::vector<lsp::Diagnostic> &diagnostics)
930 : context(registry, MLIRContext::Threading::DISABLED),
931 contents(fileContents.str()), version(version) {
932 context.allowUnregisteredDialects();
933
934 // Split the file into separate MLIR documents.
935 // TODO: Find a way to share the split file marker with other tools. We don't
936 // want to use `splitAndProcessBuffer` here, but we do want to make sure this
937 // marker doesn't go out of sync.
938 SmallVector<StringRef, 8> subContents;
939 StringRef(contents).split(subContents, "// -----");
940 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
941 context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
942
943 uint64_t lineOffset = subContents.front().count('\n');
944 for (StringRef docContents : llvm::drop_begin(subContents)) {
945 unsigned currentNumDiags = diagnostics.size();
946 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
947 docContents, diagnostics);
948 lineOffset += docContents.count('\n');
949
950 // Adjust locations used in diagnostics to account for the offset from the
951 // beginning of the file.
952 for (lsp::Diagnostic &diag :
953 llvm::drop_begin(diagnostics, currentNumDiags)) {
954 chunk->adjustLocForChunkOffset(diag.range);
955
956 if (!diag.relatedInformation)
957 continue;
958 for (auto &it : *diag.relatedInformation)
959 if (it.location.uri == uri)
960 chunk->adjustLocForChunkOffset(it.location.range);
961 }
962 chunks.emplace_back(std::move(chunk));
963 }
964 totalNumLines = lineOffset;
965 }
966
getLocationsOf(const lsp::URIForFile & uri,lsp::Position defPos,std::vector<lsp::Location> & locations)967 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
968 lsp::Position defPos,
969 std::vector<lsp::Location> &locations) {
970 MLIRTextFileChunk &chunk = getChunkFor(defPos);
971 chunk.document.getLocationsOf(uri, defPos, locations);
972
973 // Adjust any locations within this file for the offset of this chunk.
974 if (chunk.lineOffset == 0)
975 return;
976 for (lsp::Location &loc : locations)
977 if (loc.uri == uri)
978 chunk.adjustLocForChunkOffset(loc.range);
979 }
980
findReferencesOf(const lsp::URIForFile & uri,lsp::Position pos,std::vector<lsp::Location> & references)981 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
982 lsp::Position pos,
983 std::vector<lsp::Location> &references) {
984 MLIRTextFileChunk &chunk = getChunkFor(pos);
985 chunk.document.findReferencesOf(uri, pos, references);
986
987 // Adjust any locations within this file for the offset of this chunk.
988 if (chunk.lineOffset == 0)
989 return;
990 for (lsp::Location &loc : references)
991 if (loc.uri == uri)
992 chunk.adjustLocForChunkOffset(loc.range);
993 }
994
findHover(const lsp::URIForFile & uri,lsp::Position hoverPos)995 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
996 lsp::Position hoverPos) {
997 MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
998 Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
999
1000 // Adjust any locations within this file for the offset of this chunk.
1001 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1002 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1003 return hoverInfo;
1004 }
1005
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)1006 void MLIRTextFile::findDocumentSymbols(
1007 std::vector<lsp::DocumentSymbol> &symbols) {
1008 if (chunks.size() == 1)
1009 return chunks.front()->document.findDocumentSymbols(symbols);
1010
1011 // If there are multiple chunks in this file, we create top-level symbols for
1012 // each chunk.
1013 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1014 MLIRTextFileChunk &chunk = *chunks[i];
1015 lsp::Position startPos(chunk.lineOffset);
1016 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1017 : chunks[i + 1]->lineOffset);
1018 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1019 lsp::SymbolKind::Namespace,
1020 /*range=*/lsp::Range(startPos, endPos),
1021 /*selectionRange=*/lsp::Range(startPos));
1022 chunk.document.findDocumentSymbols(symbol.children);
1023
1024 // Fixup the locations of document symbols within this chunk.
1025 if (i != 0) {
1026 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1027 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1028 symbolsToFix.push_back(&childSymbol);
1029
1030 while (!symbolsToFix.empty()) {
1031 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1032 chunk.adjustLocForChunkOffset(symbol->range);
1033 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1034
1035 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1036 symbolsToFix.push_back(&childSymbol);
1037 }
1038 }
1039
1040 // Push the symbol for this chunk.
1041 symbols.emplace_back(std::move(symbol));
1042 }
1043 }
1044
/// Complete at `completePos` using the chunk that contains it, then rebase
/// every text edit attached to the results to file-relative lines.
lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
                                                    lsp::Position completePos) {
  MLIRTextFileChunk &chunk = getChunkFor(completePos);
  const auto &registry = context.getDialectRegistry();
  lsp::CompletionList result =
      chunk.document.getCodeCompletion(uri, completePos, registry);

  auto rebase = [&](lsp::TextEdit &edit) {
    chunk.adjustLocForChunkOffset(edit.range);
  };
  for (lsp::CompletionItem &item : result.items) {
    if (item.textEdit)
      rebase(*item.textEdit);
    llvm::for_each(item.additionalTextEdits, rebase);
  }
  return result;
}
1060
getCodeActions(const lsp::URIForFile & uri,const lsp::Range & pos,const lsp::CodeActionContext & context,std::vector<lsp::CodeAction> & actions)1061 void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1062 const lsp::Range &pos,
1063 const lsp::CodeActionContext &context,
1064 std::vector<lsp::CodeAction> &actions) {
1065 // Create actions for any diagnostics in this file.
1066 for (auto &diag : context.diagnostics) {
1067 if (diag.source != "mlir")
1068 continue;
1069 lsp::Position diagPos = diag.range.start;
1070 MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1071
1072 // Add a new code action that inserts a "expected" diagnostic check.
1073 lsp::CodeAction action;
1074 action.title = "Add expected-* diagnostic checks";
1075 action.kind = lsp::CodeAction::kQuickFix.str();
1076
1077 StringRef severity;
1078 switch (diag.severity) {
1079 case lsp::DiagnosticSeverity::Error:
1080 severity = "error";
1081 break;
1082 case lsp::DiagnosticSeverity::Warning:
1083 severity = "warning";
1084 break;
1085 default:
1086 continue;
1087 }
1088
1089 // Get edits for the diagnostic.
1090 std::vector<lsp::TextEdit> edits;
1091 chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1092 diag.message, edits);
1093
1094 // Walk the related diagnostics, this is how we encode notes.
1095 if (diag.relatedInformation) {
1096 for (auto ¬eDiag : *diag.relatedInformation) {
1097 if (noteDiag.location.uri != uri)
1098 continue;
1099 diagPos = noteDiag.location.range.start;
1100 diagPos.line -= chunk.lineOffset;
1101 chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1102 noteDiag.message, edits);
1103 }
1104 }
1105 // Fixup the locations for any edits.
1106 for (lsp::TextEdit &edit : edits)
1107 chunk.adjustLocForChunkOffset(edit.range);
1108
1109 action.edit.emplace();
1110 action.edit->changes[uri.uri().str()] = std::move(edits);
1111 action.diagnostics = {diag};
1112
1113 actions.emplace_back(std::move(action));
1114 }
1115 }
1116
getChunkFor(lsp::Position & pos)1117 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1118 if (chunks.size() == 1)
1119 return *chunks.front();
1120
1121 // Search for the first chunk with a greater line offset, the previous chunk
1122 // is the one that contains `pos`.
1123 auto it = llvm::upper_bound(
1124 chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1125 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1126 });
1127 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1128 pos.line -= chunk.lineOffset;
1129 return chunk;
1130 }
1131
1132 //===----------------------------------------------------------------------===//
1133 // MLIRServer::Impl
1134 //===----------------------------------------------------------------------===//
1135
1136 struct lsp::MLIRServer::Impl {
Impllsp::MLIRServer::Impl1137 Impl(DialectRegistry ®istry) : registry(registry) {}
1138
1139 /// The registry containing dialects that can be recognized in parsed .mlir
1140 /// files.
1141 DialectRegistry ®istry;
1142
1143 /// The files held by the server, mapped by their URI file name.
1144 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1145 };
1146
1147 //===----------------------------------------------------------------------===//
1148 // MLIRServer
1149 //===----------------------------------------------------------------------===//
1150
MLIRServer(DialectRegistry & registry)1151 lsp::MLIRServer::MLIRServer(DialectRegistry ®istry)
1152 : impl(std::make_unique<Impl>(registry)) {}
1153 lsp::MLIRServer::~MLIRServer() = default;
1154
addOrUpdateDocument(const URIForFile & uri,StringRef contents,int64_t version,std::vector<Diagnostic> & diagnostics)1155 void lsp::MLIRServer::addOrUpdateDocument(
1156 const URIForFile &uri, StringRef contents, int64_t version,
1157 std::vector<Diagnostic> &diagnostics) {
1158 impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1159 uri, contents, version, impl->registry, diagnostics);
1160 }
1161
removeDocument(const URIForFile & uri)1162 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
1163 auto it = impl->files.find(uri.file());
1164 if (it == impl->files.end())
1165 return llvm::None;
1166
1167 int64_t version = it->second->getVersion();
1168 impl->files.erase(it);
1169 return version;
1170 }
1171
getLocationsOf(const URIForFile & uri,const Position & defPos,std::vector<Location> & locations)1172 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
1173 const Position &defPos,
1174 std::vector<Location> &locations) {
1175 auto fileIt = impl->files.find(uri.file());
1176 if (fileIt != impl->files.end())
1177 fileIt->second->getLocationsOf(uri, defPos, locations);
1178 }
1179
findReferencesOf(const URIForFile & uri,const Position & pos,std::vector<Location> & references)1180 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
1181 const Position &pos,
1182 std::vector<Location> &references) {
1183 auto fileIt = impl->files.find(uri.file());
1184 if (fileIt != impl->files.end())
1185 fileIt->second->findReferencesOf(uri, pos, references);
1186 }
1187
findHover(const URIForFile & uri,const Position & hoverPos)1188 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1189 const Position &hoverPos) {
1190 auto fileIt = impl->files.find(uri.file());
1191 if (fileIt != impl->files.end())
1192 return fileIt->second->findHover(uri, hoverPos);
1193 return llvm::None;
1194 }
1195
findDocumentSymbols(const URIForFile & uri,std::vector<DocumentSymbol> & symbols)1196 void lsp::MLIRServer::findDocumentSymbols(
1197 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1198 auto fileIt = impl->files.find(uri.file());
1199 if (fileIt != impl->files.end())
1200 fileIt->second->findDocumentSymbols(symbols);
1201 }
1202
1203 lsp::CompletionList
getCodeCompletion(const URIForFile & uri,const Position & completePos)1204 lsp::MLIRServer::getCodeCompletion(const URIForFile &uri,
1205 const Position &completePos) {
1206 auto fileIt = impl->files.find(uri.file());
1207 if (fileIt != impl->files.end())
1208 return fileIt->second->getCodeCompletion(uri, completePos);
1209 return CompletionList();
1210 }
1211
getCodeActions(const URIForFile & uri,const Range & pos,const CodeActionContext & context,std::vector<CodeAction> & actions)1212 void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,
1213 const CodeActionContext &context,
1214 std::vector<CodeAction> &actions) {
1215 auto fileIt = impl->files.find(uri.file());
1216 if (fileIt != impl->files.end())
1217 fileIt->second->getCodeActions(uri, pos, context, actions);
1218 }
1219