1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MLIRServer.h" 10 #include "../lsp-server-support/Logging.h" 11 #include "../lsp-server-support/Protocol.h" 12 #include "../lsp-server-support/SourceMgrUtils.h" 13 #include "mlir/IR/FunctionInterfaces.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/Parser/AsmParserState.h" 16 #include "mlir/Parser/CodeComplete.h" 17 #include "mlir/Parser/Parser.h" 18 #include "llvm/Support/SourceMgr.h" 19 20 using namespace mlir; 21 22 /// Returns a language server location from the given MLIR file location. 23 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) { 24 llvm::Expected<lsp::URIForFile> sourceURI = 25 lsp::URIForFile::fromFile(loc.getFilename()); 26 if (!sourceURI) { 27 lsp::Logger::error("Failed to create URI for file `{0}`: {1}", 28 loc.getFilename(), 29 llvm::toString(sourceURI.takeError())); 30 return llvm::None; 31 } 32 33 lsp::Position position; 34 position.line = loc.getLine() - 1; 35 position.character = loc.getColumn() ? loc.getColumn() - 1 : 0; 36 return lsp::Location{*sourceURI, lsp::Range(position)}; 37 } 38 39 /// Returns a language server location from the given MLIR location, or None if 40 /// one couldn't be created. `uri` is an optional additional filter that, when 41 /// present, is used to filter sub locations that do not share the same uri. 42 static Optional<lsp::Location> 43 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc, 44 const lsp::URIForFile *uri = nullptr) { 45 Optional<lsp::Location> location; 46 loc->walk([&](Location nestedLoc) { 47 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 48 if (!fileLoc) 49 return WalkResult::advance(); 50 51 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 52 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) { 53 location = *sourceLoc; 54 SMLoc loc = sourceMgr.FindLocForLineAndColumn( 55 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn()); 56 57 // Use range of potential identifier starting at location, else length 1 58 // range. 59 location->range.end.character += 1; 60 if (Optional<SMRange> range = lsp::convertTokenLocToRange(loc)) { 61 auto lineCol = sourceMgr.getLineAndColumn(range->End); 62 location->range.end.character = 63 std::max(fileLoc.getColumn() + 1, lineCol.second - 1); 64 } 65 return WalkResult::interrupt(); 66 } 67 return WalkResult::advance(); 68 }); 69 return location; 70 } 71 72 /// Collect all of the locations from the given MLIR location that are not 73 /// contained within the given URI. 74 static void collectLocationsFromLoc(Location loc, 75 std::vector<lsp::Location> &locations, 76 const lsp::URIForFile &uri) { 77 SetVector<Location> visitedLocs; 78 loc->walk([&](Location nestedLoc) { 79 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 80 if (!fileLoc || !visitedLocs.insert(nestedLoc)) 81 return WalkResult::advance(); 82 83 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 84 if (sourceLoc && sourceLoc->uri != uri) 85 locations.push_back(*sourceLoc); 86 return WalkResult::advance(); 87 }); 88 } 89 90 /// Returns true if the given range contains the given source location. Note 91 /// that this has slightly different behavior than SMRange because it is 92 /// inclusive of the end location. 93 static bool contains(SMRange range, SMLoc loc) { 94 return range.Start.getPointer() <= loc.getPointer() && 95 loc.getPointer() <= range.End.getPointer(); 96 } 97 98 /// Returns true if the given location is contained by the definition or one of 99 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to 100 /// the range within `def` that the provided `loc` overlapped with. 101 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, 102 SMRange *overlappedRange = nullptr) { 103 // Check the main definition. 104 if (contains(def.loc, loc)) { 105 if (overlappedRange) 106 *overlappedRange = def.loc; 107 return true; 108 } 109 110 // Check the uses. 111 const auto *useIt = llvm::find_if( 112 def.uses, [&](const SMRange &range) { return contains(range, loc); }); 113 if (useIt != def.uses.end()) { 114 if (overlappedRange) 115 *overlappedRange = *useIt; 116 return true; 117 } 118 return false; 119 } 120 121 /// Given a location pointing to a result, return the result number it refers 122 /// to or None if it refers to all of the results. 123 static Optional<unsigned> getResultNumberFromLoc(SMLoc loc) { 124 // Skip all of the identifier characters. 125 auto isIdentifierChar = [](char c) { 126 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' || 127 c == '-'; 128 }; 129 const char *curPtr = loc.getPointer(); 130 while (isIdentifierChar(*curPtr)) 131 ++curPtr; 132 133 // Check to see if this location indexes into the result group, via `#`. If it 134 // doesn't, we can't extract a sub result number. 135 if (*curPtr != '#') 136 return llvm::None; 137 138 // Compute the sub result number from the remaining portion of the string. 139 const char *numberStart = ++curPtr; 140 while (llvm::isDigit(*curPtr)) 141 ++curPtr; 142 StringRef numberStr(numberStart, curPtr - numberStart); 143 unsigned resultNumber = 0; 144 return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>() 145 : resultNumber; 146 } 147 148 /// Given a source location range, return the text covered by the given range. 149 /// If the range is invalid, returns None. 150 static Optional<StringRef> getTextFromRange(SMRange range) { 151 if (!range.isValid()) 152 return None; 153 const char *startPtr = range.Start.getPointer(); 154 return StringRef(startPtr, range.End.getPointer() - startPtr); 155 } 156 157 /// Given a block, return its position in its parent region. 158 static unsigned getBlockNumber(Block *block) { 159 return std::distance(block->getParent()->begin(), block->getIterator()); 160 } 161 162 /// Given a block and source location, print the source name of the block to the 163 /// given output stream. 164 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) { 165 // Try to extract a name from the source location. 166 Optional<StringRef> text = getTextFromRange(loc); 167 if (text && text->startswith("^")) { 168 os << *text; 169 return; 170 } 171 172 // Otherwise, we don't have a name so print the block number. 173 os << "<Block #" << getBlockNumber(block) << ">"; 174 } 175 static void printDefBlockName(raw_ostream &os, 176 const AsmParserState::BlockDefinition &def) { 177 printDefBlockName(os, def.block, def.definition.loc); 178 } 179 180 /// Convert the given MLIR diagnostic to the LSP form. 181 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, 182 Diagnostic &diag, 183 const lsp::URIForFile &uri) { 184 lsp::Diagnostic lspDiag; 185 lspDiag.source = "mlir"; 186 187 // Note: Right now all of the diagnostics are treated as parser issues, but 188 // some are parser and some are verifier. 189 lspDiag.category = "Parse Error"; 190 191 // Try to grab a file location for this diagnostic. 192 // TODO: For simplicity, we just grab the first one. It may be likely that we 193 // will need a more interesting heuristic here.' 194 Optional<lsp::Location> lspLocation = 195 getLocationFromLoc(sourceMgr, diag.getLocation(), &uri); 196 if (lspLocation) 197 lspDiag.range = lspLocation->range; 198 199 // Convert the severity for the diagnostic. 200 switch (diag.getSeverity()) { 201 case DiagnosticSeverity::Note: 202 llvm_unreachable("expected notes to be handled separately"); 203 case DiagnosticSeverity::Warning: 204 lspDiag.severity = lsp::DiagnosticSeverity::Warning; 205 break; 206 case DiagnosticSeverity::Error: 207 lspDiag.severity = lsp::DiagnosticSeverity::Error; 208 break; 209 case DiagnosticSeverity::Remark: 210 lspDiag.severity = lsp::DiagnosticSeverity::Information; 211 break; 212 } 213 lspDiag.message = diag.str(); 214 215 // Attach any notes to the main diagnostic as related information. 216 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags; 217 for (Diagnostic ¬e : diag.getNotes()) { 218 lsp::Location noteLoc; 219 if (Optional<lsp::Location> loc = 220 getLocationFromLoc(sourceMgr, note.getLocation())) 221 noteLoc = *loc; 222 else 223 noteLoc.uri = uri; 224 relatedDiags.emplace_back(noteLoc, note.str()); 225 } 226 if (!relatedDiags.empty()) 227 lspDiag.relatedInformation = std::move(relatedDiags); 228 229 return lspDiag; 230 } 231 232 //===----------------------------------------------------------------------===// 233 // MLIRDocument 234 //===----------------------------------------------------------------------===// 235 236 namespace { 237 /// This class represents all of the information pertaining to a specific MLIR 238 /// document. 239 struct MLIRDocument { 240 MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, 241 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics); 242 MLIRDocument(const MLIRDocument &) = delete; 243 MLIRDocument &operator=(const MLIRDocument &) = delete; 244 245 //===--------------------------------------------------------------------===// 246 // Definitions and References 247 //===--------------------------------------------------------------------===// 248 249 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, 250 std::vector<lsp::Location> &locations); 251 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, 252 std::vector<lsp::Location> &references); 253 254 //===--------------------------------------------------------------------===// 255 // Hover 256 //===--------------------------------------------------------------------===// 257 258 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 259 const lsp::Position &hoverPos); 260 Optional<lsp::Hover> 261 buildHoverForOperation(SMRange hoverRange, 262 const AsmParserState::OperationDefinition &op); 263 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op, 264 unsigned resultStart, 265 unsigned resultEnd, SMLoc posLoc); 266 lsp::Hover buildHoverForBlock(SMRange hoverRange, 267 const AsmParserState::BlockDefinition &block); 268 lsp::Hover 269 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg, 270 const AsmParserState::BlockDefinition &block); 271 272 //===--------------------------------------------------------------------===// 273 // Document Symbols 274 //===--------------------------------------------------------------------===// 275 276 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 277 void findDocumentSymbols(Operation *op, 278 std::vector<lsp::DocumentSymbol> &symbols); 279 280 //===--------------------------------------------------------------------===// 281 // Code Completion 282 //===--------------------------------------------------------------------===// 283 284 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, 285 const lsp::Position &completePos, 286 const DialectRegistry ®istry); 287 288 //===--------------------------------------------------------------------===// 289 // Fields 290 //===--------------------------------------------------------------------===// 291 292 /// The high level parser state used to find definitions and references within 293 /// the source file. 294 AsmParserState asmState; 295 296 /// The container for the IR parsed from the input file. 297 Block parsedIR; 298 299 /// The source manager containing the contents of the input file. 300 llvm::SourceMgr sourceMgr; 301 }; 302 } // namespace 303 304 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, 305 StringRef contents, 306 std::vector<lsp::Diagnostic> &diagnostics) { 307 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) { 308 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri)); 309 }); 310 311 // Try to parsed the given IR string. 312 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); 313 if (!memBuffer) { 314 lsp::Logger::error("Failed to create memory buffer for file", uri.file()); 315 return; 316 } 317 318 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 319 if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr, 320 &asmState))) { 321 // If parsing failed, clear out any of the current state. 322 parsedIR.clear(); 323 asmState = AsmParserState(); 324 return; 325 } 326 } 327 328 //===----------------------------------------------------------------------===// 329 // MLIRDocument: Definitions and References 330 //===----------------------------------------------------------------------===// 331 332 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri, 333 const lsp::Position &defPos, 334 std::vector<lsp::Location> &locations) { 335 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); 336 337 // Functor used to check if an SM definition contains the position. 338 auto containsPosition = [&](const AsmParserState::SMDefinition &def) { 339 if (!isDefOrUse(def, posLoc)) 340 return false; 341 locations.emplace_back(uri, sourceMgr, def.loc); 342 return true; 343 }; 344 345 // Check all definitions related to operations. 346 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 347 if (contains(op.loc, posLoc)) 348 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 349 for (const auto &result : op.resultGroups) 350 if (containsPosition(result.definition)) 351 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 352 for (const auto &symUse : op.symbolUses) { 353 if (contains(symUse, posLoc)) { 354 locations.emplace_back(uri, sourceMgr, op.loc); 355 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 356 } 357 } 358 } 359 360 // Check all definitions related to blocks. 361 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 362 if (containsPosition(block.definition)) 363 return; 364 for (const AsmParserState::SMDefinition &arg : block.arguments) 365 if (containsPosition(arg)) 366 return; 367 } 368 } 369 370 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri, 371 const lsp::Position &pos, 372 std::vector<lsp::Location> &references) { 373 // Functor used to append all of the definitions/uses of the given SM 374 // definition to the reference list. 375 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) { 376 references.emplace_back(uri, sourceMgr, def.loc); 377 for (const SMRange &use : def.uses) 378 references.emplace_back(uri, sourceMgr, use); 379 }; 380 381 SMLoc posLoc = pos.getAsSMLoc(sourceMgr); 382 383 // Check all definitions related to operations. 384 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 385 if (contains(op.loc, posLoc)) { 386 for (const auto &result : op.resultGroups) 387 appendSMDef(result.definition); 388 for (const auto &symUse : op.symbolUses) 389 if (contains(symUse, posLoc)) 390 references.emplace_back(uri, sourceMgr, symUse); 391 return; 392 } 393 for (const auto &result : op.resultGroups) 394 if (isDefOrUse(result.definition, posLoc)) 395 return appendSMDef(result.definition); 396 for (const auto &symUse : op.symbolUses) { 397 if (!contains(symUse, posLoc)) 398 continue; 399 for (const auto &symUse : op.symbolUses) 400 references.emplace_back(uri, sourceMgr, symUse); 401 return; 402 } 403 } 404 405 // Check all definitions related to blocks. 406 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 407 if (isDefOrUse(block.definition, posLoc)) 408 return appendSMDef(block.definition); 409 410 for (const AsmParserState::SMDefinition &arg : block.arguments) 411 if (isDefOrUse(arg, posLoc)) 412 return appendSMDef(arg); 413 } 414 } 415 416 //===----------------------------------------------------------------------===// 417 // MLIRDocument: Hover 418 //===----------------------------------------------------------------------===// 419 420 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri, 421 const lsp::Position &hoverPos) { 422 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr); 423 SMRange hoverRange; 424 425 // Check for Hovers on operations and results. 426 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 427 // Check if the position points at this operation. 428 if (contains(op.loc, posLoc)) 429 return buildHoverForOperation(op.loc, op); 430 431 // Check if the position points at the symbol name. 432 for (auto &use : op.symbolUses) 433 if (contains(use, posLoc)) 434 return buildHoverForOperation(use, op); 435 436 // Check if the position points at a result group. 437 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) { 438 const auto &result = op.resultGroups[i]; 439 if (!isDefOrUse(result.definition, posLoc, &hoverRange)) 440 continue; 441 442 // Get the range of results covered by the over position. 443 unsigned resultStart = result.startIndex; 444 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults() 445 : op.resultGroups[i + 1].startIndex; 446 return buildHoverForOperationResult(hoverRange, op.op, resultStart, 447 resultEnd, posLoc); 448 } 449 } 450 451 // Check to see if the hover is over a block argument. 452 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 453 if (isDefOrUse(block.definition, posLoc, &hoverRange)) 454 return buildHoverForBlock(hoverRange, block); 455 456 for (const auto &arg : llvm::enumerate(block.arguments)) { 457 if (!isDefOrUse(arg.value(), posLoc, &hoverRange)) 458 continue; 459 460 return buildHoverForBlockArgument( 461 hoverRange, block.block->getArgument(arg.index()), block); 462 } 463 } 464 return llvm::None; 465 } 466 467 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation( 468 SMRange hoverRange, const AsmParserState::OperationDefinition &op) { 469 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 470 llvm::raw_string_ostream os(hover.contents.value); 471 472 // Add the operation name to the hover. 473 os << "\"" << op.op->getName() << "\""; 474 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op)) 475 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << ""; 476 os << "\n\n"; 477 478 os << "Generic Form:\n\n```mlir\n"; 479 480 // Temporary drop the regions of this operation so that they don't get 481 // printed in the output. This helps keeps the size of the output hover 482 // small. 483 SmallVector<std::unique_ptr<Region>> regions; 484 for (Region ®ion : op.op->getRegions()) { 485 regions.emplace_back(std::make_unique<Region>()); 486 regions.back()->takeBody(region); 487 } 488 489 op.op->print( 490 os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); 491 os << "\n```\n"; 492 493 // Move the regions back to the current operation. 494 for (Region ®ion : op.op->getRegions()) 495 region.takeBody(*regions.back()); 496 497 return hover; 498 } 499 500 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange, 501 Operation *op, 502 unsigned resultStart, 503 unsigned resultEnd, 504 SMLoc posLoc) { 505 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 506 llvm::raw_string_ostream os(hover.contents.value); 507 508 // Add the parent operation name to the hover. 509 os << "Operation: \"" << op->getName() << "\"\n\n"; 510 511 // Check to see if the location points to a specific result within the 512 // group. 513 if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) { 514 if ((resultStart + *resultNumber) < resultEnd) { 515 resultStart += *resultNumber; 516 resultEnd = resultStart + 1; 517 } 518 } 519 520 // Add the range of results and their types to the hover info. 521 if ((resultStart + 1) == resultEnd) { 522 os << "Result #" << resultStart << "\n\n" 523 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n"; 524 } else { 525 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n" 526 << "Types: "; 527 llvm::interleaveComma( 528 op->getResults().slice(resultStart, resultEnd), os, 529 [&](Value result) { os << "`" << result.getType() << "`"; }); 530 } 531 532 return hover; 533 } 534 535 lsp::Hover 536 MLIRDocument::buildHoverForBlock(SMRange hoverRange, 537 const AsmParserState::BlockDefinition &block) { 538 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 539 llvm::raw_string_ostream os(hover.contents.value); 540 541 // Print the given block to the hover output stream. 542 auto printBlockToHover = [&](Block *newBlock) { 543 if (const auto *def = asmState.getBlockDef(newBlock)) 544 printDefBlockName(os, *def); 545 else 546 printDefBlockName(os, newBlock); 547 }; 548 549 // Display the parent operation, block number, predecessors, and successors. 550 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 551 << "Block #" << getBlockNumber(block.block) << "\n\n"; 552 if (!block.block->hasNoPredecessors()) { 553 os << "Predecessors: "; 554 llvm::interleaveComma(block.block->getPredecessors(), os, 555 printBlockToHover); 556 os << "\n\n"; 557 } 558 if (!block.block->hasNoSuccessors()) { 559 os << "Successors: "; 560 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover); 561 os << "\n\n"; 562 } 563 564 return hover; 565 } 566 567 lsp::Hover MLIRDocument::buildHoverForBlockArgument( 568 SMRange hoverRange, BlockArgument arg, 569 const AsmParserState::BlockDefinition &block) { 570 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 571 llvm::raw_string_ostream os(hover.contents.value); 572 573 // Display the parent operation, block, the argument number, and the type. 574 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 575 << "Block: "; 576 printDefBlockName(os, block); 577 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n" 578 << "Type: `" << arg.getType() << "`\n\n"; 579 580 return hover; 581 } 582 583 //===----------------------------------------------------------------------===// 584 // MLIRDocument: Document Symbols 585 //===----------------------------------------------------------------------===// 586 587 void MLIRDocument::findDocumentSymbols( 588 std::vector<lsp::DocumentSymbol> &symbols) { 589 for (Operation &op : parsedIR) 590 findDocumentSymbols(&op, symbols); 591 } 592 593 void MLIRDocument::findDocumentSymbols( 594 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) { 595 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols; 596 597 // Check for the source information of this operation. 598 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) { 599 // If this operation defines a symbol, record it. 600 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) { 601 symbols.emplace_back(symbol.getName(), 602 isa<FunctionOpInterface>(op) 603 ? lsp::SymbolKind::Function 604 : lsp::SymbolKind::Class, 605 lsp::Range(sourceMgr, def->scopeLoc), 606 lsp::Range(sourceMgr, def->loc)); 607 childSymbols = &symbols.back().children; 608 609 } else if (op->hasTrait<OpTrait::SymbolTable>()) { 610 // Otherwise, if this is a symbol table push an anonymous document symbol. 611 symbols.emplace_back("<" + op->getName().getStringRef() + ">", 612 lsp::SymbolKind::Namespace, 613 lsp::Range(sourceMgr, def->scopeLoc), 614 lsp::Range(sourceMgr, def->loc)); 615 childSymbols = &symbols.back().children; 616 } 617 } 618 619 // Recurse into the regions of this operation. 620 if (!op->getNumRegions()) 621 return; 622 for (Region ®ion : op->getRegions()) 623 for (Operation &childOp : region.getOps()) 624 findDocumentSymbols(&childOp, *childSymbols); 625 } 626 627 //===----------------------------------------------------------------------===// 628 // MLIRDocument: Code Completion 629 //===----------------------------------------------------------------------===// 630 631 namespace { 632 class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { 633 public: 634 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList, 635 MLIRContext *ctx) 636 : AsmParserCodeCompleteContext(completeLoc), 637 completionList(completionList), ctx(ctx) {} 638 639 /// Signal code completion for a dialect name. 640 void completeDialectName() final { 641 for (StringRef dialect : ctx->getAvailableDialects()) { 642 lsp::CompletionItem item(dialect, lsp::CompletionItemKind::Module); 643 item.sortText = "2"; 644 item.detail = "dialect"; 645 completionList.items.emplace_back(item); 646 } 647 } 648 649 /// Signal code completion for an operation name within the given dialect. 650 void completeOperationName(StringRef dialectName) final { 651 Dialect *dialect = ctx->getOrLoadDialect(dialectName); 652 if (!dialect) 653 return; 654 655 for (const auto &op : ctx->getRegisteredOperations()) { 656 if (&op.getDialect() != dialect) 657 continue; 658 659 lsp::CompletionItem item( 660 op.getStringRef().drop_front(dialectName.size() + 1), 661 lsp::CompletionItemKind::Field); 662 item.sortText = "1"; 663 item.detail = "operation"; 664 completionList.items.emplace_back(item); 665 } 666 } 667 668 /// Append the given SSA value as a code completion result for SSA value 669 /// completions. 670 void appendSSAValueCompletion(StringRef name, std::string typeData) final { 671 // Check if we need to insert the `%` or not. 672 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%'; 673 674 lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable); 675 if (stripPrefix) 676 item.insertText = name.drop_front(1).str(); 677 item.detail = std::move(typeData); 678 completionList.items.emplace_back(item); 679 } 680 681 /// Append the given block as a code completion result for block name 682 /// completions. 683 void appendBlockCompletion(StringRef name) final { 684 // Check if we need to insert the `^` or not. 685 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^'; 686 687 lsp::CompletionItem item(name, lsp::CompletionItemKind::Field); 688 if (stripPrefix) 689 item.insertText = name.drop_front(1).str(); 690 completionList.items.emplace_back(item); 691 } 692 693 private: 694 lsp::CompletionList &completionList; 695 MLIRContext *ctx; 696 }; 697 } // namespace 698 699 lsp::CompletionList 700 MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri, 701 const lsp::Position &completePos, 702 const DialectRegistry ®istry) { 703 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr); 704 if (!posLoc.isValid()) 705 return lsp::CompletionList(); 706 707 // To perform code completion, we run another parse of the module with the 708 // code completion context provided. 709 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED); 710 tmpContext.allowUnregisteredDialects(); 711 lsp::CompletionList completionList; 712 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList, 713 &tmpContext); 714 715 Block tmpIR; 716 AsmParserState tmpState; 717 (void)parseSourceFile(sourceMgr, &tmpIR, &tmpContext, 718 /*sourceFileLoc=*/nullptr, &tmpState, 719 &lspCompleteContext); 720 return completionList; 721 } 722 723 //===----------------------------------------------------------------------===// 724 // MLIRTextFileChunk 725 //===----------------------------------------------------------------------===// 726 727 namespace { 728 /// This class represents a single chunk of an MLIR text file. 729 struct MLIRTextFileChunk { 730 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset, 731 const lsp::URIForFile &uri, StringRef contents, 732 std::vector<lsp::Diagnostic> &diagnostics) 733 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {} 734 735 /// Adjust the line number of the given range to anchor at the beginning of 736 /// the file, instead of the beginning of this chunk. 737 void adjustLocForChunkOffset(lsp::Range &range) { 738 adjustLocForChunkOffset(range.start); 739 adjustLocForChunkOffset(range.end); 740 } 741 /// Adjust the line number of the given position to anchor at the beginning of 742 /// the file, instead of the beginning of this chunk. 743 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } 744 745 /// The line offset of this chunk from the beginning of the file. 746 uint64_t lineOffset; 747 /// The document referred to by this chunk. 748 MLIRDocument document; 749 }; 750 } // namespace 751 752 //===----------------------------------------------------------------------===// 753 // MLIRTextFile 754 //===----------------------------------------------------------------------===// 755 756 namespace { 757 /// This class represents a text file containing one or more MLIR documents. 758 class MLIRTextFile { 759 public: 760 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 761 int64_t version, DialectRegistry ®istry, 762 std::vector<lsp::Diagnostic> &diagnostics); 763 764 /// Return the current version of this text file. 765 int64_t getVersion() const { return version; } 766 767 //===--------------------------------------------------------------------===// 768 // LSP Queries 769 //===--------------------------------------------------------------------===// 770 771 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos, 772 std::vector<lsp::Location> &locations); 773 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos, 774 std::vector<lsp::Location> &references); 775 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 776 lsp::Position hoverPos); 777 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 778 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, 779 lsp::Position completePos); 780 781 private: 782 /// Find the MLIR document that contains the given position, and update the 783 /// position to be anchored at the start of the found chunk instead of the 784 /// beginning of the file. 785 MLIRTextFileChunk &getChunkFor(lsp::Position &pos); 786 787 /// The context used to hold the state contained by the parsed document. 788 MLIRContext context; 789 790 /// The full string contents of the file. 791 std::string contents; 792 793 /// The version of this file. 794 int64_t version; 795 796 /// The number of lines in the file. 797 int64_t totalNumLines = 0; 798 799 /// The chunks of this file. The order of these chunks is the order in which 800 /// they appear in the text file. 801 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks; 802 }; 803 } // namespace 804 805 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 806 int64_t version, DialectRegistry ®istry, 807 std::vector<lsp::Diagnostic> &diagnostics) 808 : context(registry, MLIRContext::Threading::DISABLED), 809 contents(fileContents.str()), version(version) { 810 context.allowUnregisteredDialects(); 811 812 // Split the file into separate MLIR documents. 813 // TODO: Find a way to share the split file marker with other tools. We don't 814 // want to use `splitAndProcessBuffer` here, but we do want to make sure this 815 // marker doesn't go out of sync. 816 SmallVector<StringRef, 8> subContents; 817 StringRef(contents).split(subContents, "// -----"); 818 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>( 819 context, /*lineOffset=*/0, uri, subContents.front(), diagnostics)); 820 821 uint64_t lineOffset = subContents.front().count('\n'); 822 for (StringRef docContents : llvm::drop_begin(subContents)) { 823 unsigned currentNumDiags = diagnostics.size(); 824 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri, 825 docContents, diagnostics); 826 lineOffset += docContents.count('\n'); 827 828 // Adjust locations used in diagnostics to account for the offset from the 829 // beginning of the file. 830 for (lsp::Diagnostic &diag : 831 llvm::drop_begin(diagnostics, currentNumDiags)) { 832 chunk->adjustLocForChunkOffset(diag.range); 833 834 if (!diag.relatedInformation) 835 continue; 836 for (auto &it : *diag.relatedInformation) 837 if (it.location.uri == uri) 838 chunk->adjustLocForChunkOffset(it.location.range); 839 } 840 chunks.emplace_back(std::move(chunk)); 841 } 842 totalNumLines = lineOffset; 843 } 844 845 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri, 846 lsp::Position defPos, 847 std::vector<lsp::Location> &locations) { 848 MLIRTextFileChunk &chunk = getChunkFor(defPos); 849 chunk.document.getLocationsOf(uri, defPos, locations); 850 851 // Adjust any locations within this file for the offset of this chunk. 852 if (chunk.lineOffset == 0) 853 return; 854 for (lsp::Location &loc : locations) 855 if (loc.uri == uri) 856 chunk.adjustLocForChunkOffset(loc.range); 857 } 858 859 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri, 860 lsp::Position pos, 861 std::vector<lsp::Location> &references) { 862 MLIRTextFileChunk &chunk = getChunkFor(pos); 863 chunk.document.findReferencesOf(uri, pos, references); 864 865 // Adjust any locations within this file for the offset of this chunk. 866 if (chunk.lineOffset == 0) 867 return; 868 for (lsp::Location &loc : references) 869 if (loc.uri == uri) 870 chunk.adjustLocForChunkOffset(loc.range); 871 } 872 873 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri, 874 lsp::Position hoverPos) { 875 MLIRTextFileChunk &chunk = getChunkFor(hoverPos); 876 Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos); 877 878 // Adjust any locations within this file for the offset of this chunk. 879 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) 880 chunk.adjustLocForChunkOffset(*hoverInfo->range); 881 return hoverInfo; 882 } 883 884 void MLIRTextFile::findDocumentSymbols( 885 std::vector<lsp::DocumentSymbol> &symbols) { 886 if (chunks.size() == 1) 887 return chunks.front()->document.findDocumentSymbols(symbols); 888 889 // If there are multiple chunks in this file, we create top-level symbols for 890 // each chunk. 891 for (unsigned i = 0, e = chunks.size(); i < e; ++i) { 892 MLIRTextFileChunk &chunk = *chunks[i]; 893 lsp::Position startPos(chunk.lineOffset); 894 lsp::Position endPos((i == e - 1) ? totalNumLines - 1 895 : chunks[i + 1]->lineOffset); 896 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">", 897 lsp::SymbolKind::Namespace, 898 /*range=*/lsp::Range(startPos, endPos), 899 /*selectionRange=*/lsp::Range(startPos)); 900 chunk.document.findDocumentSymbols(symbol.children); 901 902 // Fixup the locations of document symbols within this chunk. 903 if (i != 0) { 904 SmallVector<lsp::DocumentSymbol *> symbolsToFix; 905 for (lsp::DocumentSymbol &childSymbol : symbol.children) 906 symbolsToFix.push_back(&childSymbol); 907 908 while (!symbolsToFix.empty()) { 909 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); 910 chunk.adjustLocForChunkOffset(symbol->range); 911 chunk.adjustLocForChunkOffset(symbol->selectionRange); 912 913 for (lsp::DocumentSymbol &childSymbol : symbol->children) 914 symbolsToFix.push_back(&childSymbol); 915 } 916 } 917 918 // Push the symbol for this chunk. 919 symbols.emplace_back(std::move(symbol)); 920 } 921 } 922 923 lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri, 924 lsp::Position completePos) { 925 MLIRTextFileChunk &chunk = getChunkFor(completePos); 926 lsp::CompletionList completionList = chunk.document.getCodeCompletion( 927 uri, completePos, context.getDialectRegistry()); 928 929 // Adjust any completion locations. 930 for (lsp::CompletionItem &item : completionList.items) { 931 if (item.textEdit) 932 chunk.adjustLocForChunkOffset(item.textEdit->range); 933 for (lsp::TextEdit &edit : item.additionalTextEdits) 934 chunk.adjustLocForChunkOffset(edit.range); 935 } 936 return completionList; 937 } 938 939 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { 940 if (chunks.size() == 1) 941 return *chunks.front(); 942 943 // Search for the first chunk with a greater line offset, the previous chunk 944 // is the one that contains `pos`. 945 auto it = llvm::upper_bound( 946 chunks, pos, [](const lsp::Position &pos, const auto &chunk) { 947 return static_cast<uint64_t>(pos.line) < chunk->lineOffset; 948 }); 949 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it); 950 pos.line -= chunk.lineOffset; 951 return chunk; 952 } 953 954 //===----------------------------------------------------------------------===// 955 // MLIRServer::Impl 956 //===----------------------------------------------------------------------===// 957 958 struct lsp::MLIRServer::Impl { 959 Impl(DialectRegistry ®istry) : registry(registry) {} 960 961 /// The registry containing dialects that can be recognized in parsed .mlir 962 /// files. 963 DialectRegistry ®istry; 964 965 /// The files held by the server, mapped by their URI file name. 966 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files; 967 }; 968 969 //===----------------------------------------------------------------------===// 970 // MLIRServer 971 //===----------------------------------------------------------------------===// 972 973 lsp::MLIRServer::MLIRServer(DialectRegistry ®istry) 974 : impl(std::make_unique<Impl>(registry)) {} 975 lsp::MLIRServer::~MLIRServer() = default; 976 977 void lsp::MLIRServer::addOrUpdateDocument( 978 const URIForFile &uri, StringRef contents, int64_t version, 979 std::vector<Diagnostic> &diagnostics) { 980 impl->files[uri.file()] = std::make_unique<MLIRTextFile>( 981 uri, contents, version, impl->registry, diagnostics); 982 } 983 984 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) { 985 auto it = impl->files.find(uri.file()); 986 if (it == impl->files.end()) 987 return llvm::None; 988 989 int64_t version = it->second->getVersion(); 990 impl->files.erase(it); 991 return version; 992 } 993 994 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, 995 const Position &defPos, 996 std::vector<Location> &locations) { 997 auto fileIt = impl->files.find(uri.file()); 998 if (fileIt != impl->files.end()) 999 fileIt->second->getLocationsOf(uri, defPos, locations); 1000 } 1001 1002 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, 1003 const Position &pos, 1004 std::vector<Location> &references) { 1005 auto fileIt = impl->files.find(uri.file()); 1006 if (fileIt != impl->files.end()) 1007 fileIt->second->findReferencesOf(uri, pos, references); 1008 } 1009 1010 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri, 1011 const Position &hoverPos) { 1012 auto fileIt = impl->files.find(uri.file()); 1013 if (fileIt != impl->files.end()) 1014 return fileIt->second->findHover(uri, hoverPos); 1015 return llvm::None; 1016 } 1017 1018 void lsp::MLIRServer::findDocumentSymbols( 1019 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) { 1020 auto fileIt = impl->files.find(uri.file()); 1021 if (fileIt != impl->files.end()) 1022 fileIt->second->findDocumentSymbols(symbols); 1023 } 1024 1025 lsp::CompletionList 1026 lsp::MLIRServer::getCodeCompletion(const URIForFile &uri, 1027 const Position &completePos) { 1028 auto fileIt = impl->files.find(uri.file()); 1029 if (fileIt != impl->files.end()) 1030 return fileIt->second->getCodeCompletion(uri, completePos); 1031 return CompletionList(); 1032 } 1033