1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MLIRServer.h" 10 #include "lsp/Logging.h" 11 #include "lsp/Protocol.h" 12 #include "mlir/IR/Operation.h" 13 #include "mlir/Parser.h" 14 #include "mlir/Parser/AsmParserState.h" 15 #include "llvm/Support/SourceMgr.h" 16 17 using namespace mlir; 18 19 /// Returns a language server position for the given source location. 20 static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, llvm::SMLoc loc) { 21 std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc); 22 lsp::Position pos; 23 pos.line = lineAndCol.first - 1; 24 pos.character = lineAndCol.second - 1; 25 return pos; 26 } 27 28 /// Returns a source location from the given language server position. 29 static llvm::SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) { 30 return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), pos.line + 1, 31 pos.character); 32 } 33 34 /// Returns a language server range for the given source range. 35 static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range) { 36 return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)}; 37 } 38 39 /// Returns a language server location from the given source range. 40 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, 41 llvm::SMRange range, 42 const lsp::URIForFile &uri) { 43 return lsp::Location{uri, getRangeFromLoc(mgr, range)}; 44 } 45 46 /// Returns a language server location from the given MLIR file location. 47 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) { 48 llvm::Expected<lsp::URIForFile> sourceURI = 49 lsp::URIForFile::fromFile(loc.getFilename()); 50 if (!sourceURI) { 51 lsp::Logger::error("Failed to create URI for file `{0}`: {1}", 52 loc.getFilename(), 53 llvm::toString(sourceURI.takeError())); 54 return llvm::None; 55 } 56 57 lsp::Position position; 58 position.line = loc.getLine() - 1; 59 position.character = loc.getColumn(); 60 return lsp::Location{*sourceURI, lsp::Range(position)}; 61 } 62 63 /// Returns a language server location from the given MLIR location, or None if 64 /// one couldn't be created. `uri` is an optional additional filter that, when 65 /// present, is used to filter sub locations that do not share the same uri. 66 static Optional<lsp::Location> 67 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc, 68 const lsp::URIForFile *uri = nullptr) { 69 Optional<lsp::Location> location; 70 loc->walk([&](Location nestedLoc) { 71 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 72 if (!fileLoc) 73 return WalkResult::advance(); 74 75 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 76 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) { 77 location = *sourceLoc; 78 llvm::SMLoc loc = sourceMgr.FindLocForLineAndColumn( 79 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn()); 80 81 // Use range of potential identifier starting at location, else length 1 82 // range. 83 location->range.end.character += 1; 84 if (Optional<llvm::SMRange> range = 85 AsmParserState::convertIdLocToRange(loc)) { 86 auto lineCol = sourceMgr.getLineAndColumn(range->End); 87 location->range.end.character = 88 std::max(fileLoc.getColumn() + 1, lineCol.second - 1); 89 } 90 return WalkResult::interrupt(); 91 } 92 return WalkResult::advance(); 93 }); 94 return location; 95 } 96 97 /// Collect all of the locations from the given MLIR location that are not 98 /// contained within the given URI. 99 static void collectLocationsFromLoc(Location loc, 100 std::vector<lsp::Location> &locations, 101 const lsp::URIForFile &uri) { 102 SetVector<Location> visitedLocs; 103 loc->walk([&](Location nestedLoc) { 104 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 105 if (!fileLoc || !visitedLocs.insert(nestedLoc)) 106 return WalkResult::advance(); 107 108 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 109 if (sourceLoc && sourceLoc->uri != uri) 110 locations.push_back(*sourceLoc); 111 return WalkResult::advance(); 112 }); 113 } 114 115 /// Returns true if the given range contains the given source location. Note 116 /// that this has slightly different behavior than SMRange because it is 117 /// inclusive of the end location. 118 static bool contains(llvm::SMRange range, llvm::SMLoc loc) { 119 return range.Start.getPointer() <= loc.getPointer() && 120 loc.getPointer() <= range.End.getPointer(); 121 } 122 123 /// Returns true if the given location is contained by the definition or one of 124 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to 125 /// the range within `def` that the provided `loc` overlapped with. 126 static bool isDefOrUse(const AsmParserState::SMDefinition &def, llvm::SMLoc loc, 127 llvm::SMRange *overlappedRange = nullptr) { 128 // Check the main definition. 129 if (contains(def.loc, loc)) { 130 if (overlappedRange) 131 *overlappedRange = def.loc; 132 return true; 133 } 134 135 // Check the uses. 136 auto useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) { 137 return contains(range, loc); 138 }); 139 if (useIt != def.uses.end()) { 140 if (overlappedRange) 141 *overlappedRange = *useIt; 142 return true; 143 } 144 return false; 145 } 146 147 /// Given a location pointing to a result, return the result number it refers 148 /// to or None if it refers to all of the results. 149 static Optional<unsigned> getResultNumberFromLoc(llvm::SMLoc loc) { 150 // Skip all of the identifier characters. 151 auto isIdentifierChar = [](char c) { 152 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' || 153 c == '-'; 154 }; 155 const char *curPtr = loc.getPointer(); 156 while (isIdentifierChar(*curPtr)) 157 ++curPtr; 158 159 // Check to see if this location indexes into the result group, via `#`. If it 160 // doesn't, we can't extract a sub result number. 161 if (*curPtr != '#') 162 return llvm::None; 163 164 // Compute the sub result number from the remaining portion of the string. 165 const char *numberStart = ++curPtr; 166 while (llvm::isDigit(*curPtr)) 167 ++curPtr; 168 StringRef numberStr(numberStart, curPtr - numberStart); 169 unsigned resultNumber = 0; 170 return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>() 171 : resultNumber; 172 } 173 174 /// Given a source location range, return the text covered by the given range. 175 /// If the range is invalid, returns None. 176 static Optional<StringRef> getTextFromRange(llvm::SMRange range) { 177 if (!range.isValid()) 178 return None; 179 const char *startPtr = range.Start.getPointer(); 180 return StringRef(startPtr, range.End.getPointer() - startPtr); 181 } 182 183 /// Given a block, return its position in its parent region. 184 static unsigned getBlockNumber(Block *block) { 185 return std::distance(block->getParent()->begin(), block->getIterator()); 186 } 187 188 /// Given a block and source location, print the source name of the block to the 189 /// given output stream. 190 static void printDefBlockName(raw_ostream &os, Block *block, 191 llvm::SMRange loc = {}) { 192 // Try to extract a name from the source location. 193 Optional<StringRef> text = getTextFromRange(loc); 194 if (text && text->startswith("^")) { 195 os << *text; 196 return; 197 } 198 199 // Otherwise, we don't have a name so print the block number. 200 os << "<Block #" << getBlockNumber(block) << ">"; 201 } 202 static void printDefBlockName(raw_ostream &os, 203 const AsmParserState::BlockDefinition &def) { 204 printDefBlockName(os, def.block, def.definition.loc); 205 } 206 207 /// Convert the given MLIR diagnostic to the LSP form. 208 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, 209 Diagnostic &diag, 210 const lsp::URIForFile &uri) { 211 lsp::Diagnostic lspDiag; 212 lspDiag.source = "mlir"; 213 214 // Note: Right now all of the diagnostics are treated as parser issues, but 215 // some are parser and some are verifier. 216 lspDiag.category = "Parse Error"; 217 218 // Try to grab a file location for this diagnostic. 219 // TODO: For simplicity, we just grab the first one. It may be likely that we 220 // will need a more interesting heuristic here.' 221 Optional<lsp::Location> lspLocation = 222 getLocationFromLoc(sourceMgr, diag.getLocation(), &uri); 223 if (lspLocation) 224 lspDiag.range = lspLocation->range; 225 226 // Convert the severity for the diagnostic. 227 switch (diag.getSeverity()) { 228 case DiagnosticSeverity::Note: 229 llvm_unreachable("expected notes to be handled separately"); 230 case DiagnosticSeverity::Warning: 231 lspDiag.severity = lsp::DiagnosticSeverity::Warning; 232 break; 233 case DiagnosticSeverity::Error: 234 lspDiag.severity = lsp::DiagnosticSeverity::Error; 235 break; 236 case DiagnosticSeverity::Remark: 237 lspDiag.severity = lsp::DiagnosticSeverity::Information; 238 break; 239 } 240 lspDiag.message = diag.str(); 241 242 // Attach any notes to the main diagnostic as related information. 243 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags; 244 for (Diagnostic ¬e : diag.getNotes()) { 245 lsp::Location noteLoc; 246 if (Optional<lsp::Location> loc = 247 getLocationFromLoc(sourceMgr, note.getLocation())) 248 noteLoc = *loc; 249 else 250 noteLoc.uri = uri; 251 relatedDiags.emplace_back(noteLoc, note.str()); 252 } 253 if (!relatedDiags.empty()) 254 lspDiag.relatedInformation = std::move(relatedDiags); 255 256 return lspDiag; 257 } 258 259 //===----------------------------------------------------------------------===// 260 // MLIRDocument 261 //===----------------------------------------------------------------------===// 262 263 namespace { 264 /// This class represents all of the information pertaining to a specific MLIR 265 /// document. 266 struct MLIRDocument { 267 MLIRDocument(const lsp::URIForFile &uri, StringRef contents, 268 DialectRegistry ®istry, 269 std::vector<lsp::Diagnostic> &diagnostics); 270 MLIRDocument(const MLIRDocument &) = delete; 271 MLIRDocument &operator=(const MLIRDocument &) = delete; 272 273 //===--------------------------------------------------------------------===// 274 // Definitions and References 275 //===--------------------------------------------------------------------===// 276 277 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, 278 std::vector<lsp::Location> &locations); 279 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, 280 std::vector<lsp::Location> &references); 281 282 //===--------------------------------------------------------------------===// 283 // Hover 284 //===--------------------------------------------------------------------===// 285 286 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 287 const lsp::Position &hoverPos); 288 Optional<lsp::Hover> 289 buildHoverForOperation(llvm::SMRange hoverRange, 290 const AsmParserState::OperationDefinition &op); 291 lsp::Hover buildHoverForOperationResult(llvm::SMRange hoverRange, 292 Operation *op, unsigned resultStart, 293 unsigned resultEnd, 294 llvm::SMLoc posLoc); 295 lsp::Hover buildHoverForBlock(llvm::SMRange hoverRange, 296 const AsmParserState::BlockDefinition &block); 297 lsp::Hover 298 buildHoverForBlockArgument(llvm::SMRange hoverRange, BlockArgument arg, 299 const AsmParserState::BlockDefinition &block); 300 301 //===--------------------------------------------------------------------===// 302 // Document Symbols 303 //===--------------------------------------------------------------------===// 304 305 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 306 void findDocumentSymbols(Operation *op, 307 std::vector<lsp::DocumentSymbol> &symbols); 308 309 //===--------------------------------------------------------------------===// 310 // Fields 311 //===--------------------------------------------------------------------===// 312 313 /// The context used to hold the state contained by the parsed document. 314 MLIRContext context; 315 316 /// The high level parser state used to find definitions and references within 317 /// the source file. 318 AsmParserState asmState; 319 320 /// The container for the IR parsed from the input file. 321 Block parsedIR; 322 323 /// The source manager containing the contents of the input file. 324 llvm::SourceMgr sourceMgr; 325 }; 326 } // namespace 327 328 MLIRDocument::MLIRDocument(const lsp::URIForFile &uri, StringRef contents, 329 DialectRegistry ®istry, 330 std::vector<lsp::Diagnostic> &diagnostics) 331 : context(registry) { 332 context.allowUnregisteredDialects(); 333 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) { 334 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri)); 335 }); 336 337 // Try to parsed the given IR string. 338 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); 339 if (!memBuffer) { 340 lsp::Logger::error("Failed to create memory buffer for file", uri.file()); 341 return; 342 } 343 344 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), llvm::SMLoc()); 345 if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr, 346 &asmState))) { 347 // If parsing failed, clear out any of the current state. 348 parsedIR.clear(); 349 asmState = AsmParserState(); 350 return; 351 } 352 } 353 354 //===----------------------------------------------------------------------===// 355 // MLIRDocument: Definitions and References 356 //===----------------------------------------------------------------------===// 357 358 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri, 359 const lsp::Position &defPos, 360 std::vector<lsp::Location> &locations) { 361 llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, defPos); 362 363 // Functor used to check if an SM definition contains the position. 364 auto containsPosition = [&](const AsmParserState::SMDefinition &def) { 365 if (!isDefOrUse(def, posLoc)) 366 return false; 367 locations.push_back(getLocationFromLoc(sourceMgr, def.loc, uri)); 368 return true; 369 }; 370 371 // Check all definitions related to operations. 372 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 373 if (contains(op.loc, posLoc)) 374 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 375 for (const auto &result : op.resultGroups) 376 if (containsPosition(result.second)) 377 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 378 for (const auto &symUse : op.symbolUses) { 379 if (contains(symUse, posLoc)) { 380 locations.push_back(getLocationFromLoc(sourceMgr, op.loc, uri)); 381 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 382 } 383 } 384 } 385 386 // Check all definitions related to blocks. 387 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 388 if (containsPosition(block.definition)) 389 return; 390 for (const AsmParserState::SMDefinition &arg : block.arguments) 391 if (containsPosition(arg)) 392 return; 393 } 394 } 395 396 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri, 397 const lsp::Position &pos, 398 std::vector<lsp::Location> &references) { 399 // Functor used to append all of the definitions/uses of the given SM 400 // definition to the reference list. 401 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) { 402 references.push_back(getLocationFromLoc(sourceMgr, def.loc, uri)); 403 for (const llvm::SMRange &use : def.uses) 404 references.push_back(getLocationFromLoc(sourceMgr, use, uri)); 405 }; 406 407 llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, pos); 408 409 // Check all definitions related to operations. 410 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 411 if (contains(op.loc, posLoc)) { 412 for (const auto &result : op.resultGroups) 413 appendSMDef(result.second); 414 for (const auto &symUse : op.symbolUses) 415 if (contains(symUse, posLoc)) 416 references.push_back(getLocationFromLoc(sourceMgr, symUse, uri)); 417 return; 418 } 419 for (const auto &result : op.resultGroups) 420 if (isDefOrUse(result.second, posLoc)) 421 return appendSMDef(result.second); 422 for (const auto &symUse : op.symbolUses) { 423 if (!contains(symUse, posLoc)) 424 continue; 425 for (const auto &symUse : op.symbolUses) 426 references.push_back(getLocationFromLoc(sourceMgr, symUse, uri)); 427 return; 428 } 429 } 430 431 // Check all definitions related to blocks. 432 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 433 if (isDefOrUse(block.definition, posLoc)) 434 return appendSMDef(block.definition); 435 436 for (const AsmParserState::SMDefinition &arg : block.arguments) 437 if (isDefOrUse(arg, posLoc)) 438 return appendSMDef(arg); 439 } 440 } 441 442 //===----------------------------------------------------------------------===// 443 // MLIRDocument: Hover 444 //===----------------------------------------------------------------------===// 445 446 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri, 447 const lsp::Position &hoverPos) { 448 llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, hoverPos); 449 llvm::SMRange hoverRange; 450 451 // Check for Hovers on operations and results. 452 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 453 // Check if the position points at this operation. 454 if (contains(op.loc, posLoc)) 455 return buildHoverForOperation(op.loc, op); 456 457 // Check if the position points at the symbol name. 458 for (auto &use : op.symbolUses) 459 if (contains(use, posLoc)) 460 return buildHoverForOperation(use, op); 461 462 // Check if the position points at a result group. 463 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) { 464 const auto &result = op.resultGroups[i]; 465 if (!isDefOrUse(result.second, posLoc, &hoverRange)) 466 continue; 467 468 // Get the range of results covered by the over position. 469 unsigned resultStart = result.first; 470 unsigned resultEnd = 471 (i == e - 1) ? op.op->getNumResults() : op.resultGroups[i + 1].first; 472 return buildHoverForOperationResult(hoverRange, op.op, resultStart, 473 resultEnd, posLoc); 474 } 475 } 476 477 // Check to see if the hover is over a block argument. 478 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 479 if (isDefOrUse(block.definition, posLoc, &hoverRange)) 480 return buildHoverForBlock(hoverRange, block); 481 482 for (const auto &arg : llvm::enumerate(block.arguments)) { 483 if (!isDefOrUse(arg.value(), posLoc, &hoverRange)) 484 continue; 485 486 return buildHoverForBlockArgument( 487 hoverRange, block.block->getArgument(arg.index()), block); 488 } 489 } 490 return llvm::None; 491 } 492 493 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation( 494 llvm::SMRange hoverRange, const AsmParserState::OperationDefinition &op) { 495 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 496 llvm::raw_string_ostream os(hover.contents.value); 497 498 // Add the operation name to the hover. 499 os << "\"" << op.op->getName() << "\""; 500 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op)) 501 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << ""; 502 os << "\n\n"; 503 504 os << "Generic Form:\n\n```mlir\n"; 505 506 // Temporary drop the regions of this operation so that they don't get 507 // printed in the output. This helps keeps the size of the output hover 508 // small. 509 SmallVector<std::unique_ptr<Region>> regions; 510 for (Region ®ion : op.op->getRegions()) { 511 regions.emplace_back(std::make_unique<Region>()); 512 regions.back()->takeBody(region); 513 } 514 515 op.op->print( 516 os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); 517 os << "\n```\n"; 518 519 // Move the regions back to the current operation. 520 for (Region ®ion : op.op->getRegions()) 521 region.takeBody(*regions.back()); 522 523 return hover; 524 } 525 526 lsp::Hover MLIRDocument::buildHoverForOperationResult(llvm::SMRange hoverRange, 527 Operation *op, 528 unsigned resultStart, 529 unsigned resultEnd, 530 llvm::SMLoc posLoc) { 531 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 532 llvm::raw_string_ostream os(hover.contents.value); 533 534 // Add the parent operation name to the hover. 535 os << "Operation: \"" << op->getName() << "\"\n\n"; 536 537 // Check to see if the location points to a specific result within the 538 // group. 539 if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) { 540 if ((resultStart + *resultNumber) < resultEnd) { 541 resultStart += *resultNumber; 542 resultEnd = resultStart + 1; 543 } 544 } 545 546 // Add the range of results and their types to the hover info. 547 if ((resultStart + 1) == resultEnd) { 548 os << "Result #" << resultStart << "\n\n" 549 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n"; 550 } else { 551 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n" 552 << "Types: "; 553 llvm::interleaveComma( 554 op->getResults().slice(resultStart, resultEnd), os, 555 [&](Value result) { os << "`" << result.getType() << "`"; }); 556 } 557 558 return hover; 559 } 560 561 lsp::Hover 562 MLIRDocument::buildHoverForBlock(llvm::SMRange hoverRange, 563 const AsmParserState::BlockDefinition &block) { 564 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 565 llvm::raw_string_ostream os(hover.contents.value); 566 567 // Print the given block to the hover output stream. 568 auto printBlockToHover = [&](Block *newBlock) { 569 if (const auto *def = asmState.getBlockDef(newBlock)) 570 printDefBlockName(os, *def); 571 else 572 printDefBlockName(os, newBlock); 573 }; 574 575 // Display the parent operation, block number, predecessors, and successors. 576 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 577 << "Block #" << getBlockNumber(block.block) << "\n\n"; 578 if (!block.block->hasNoPredecessors()) { 579 os << "Predecessors: "; 580 llvm::interleaveComma(block.block->getPredecessors(), os, 581 printBlockToHover); 582 os << "\n\n"; 583 } 584 if (!block.block->hasNoSuccessors()) { 585 os << "Successors: "; 586 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover); 587 os << "\n\n"; 588 } 589 590 return hover; 591 } 592 593 lsp::Hover MLIRDocument::buildHoverForBlockArgument( 594 llvm::SMRange hoverRange, BlockArgument arg, 595 const AsmParserState::BlockDefinition &block) { 596 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 597 llvm::raw_string_ostream os(hover.contents.value); 598 599 // Display the parent operation, block, the argument number, and the type. 600 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 601 << "Block: "; 602 printDefBlockName(os, block); 603 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n" 604 << "Type: `" << arg.getType() << "`\n\n"; 605 606 return hover; 607 } 608 609 //===----------------------------------------------------------------------===// 610 // MLIRDocument: Document Symbols 611 //===----------------------------------------------------------------------===// 612 613 void MLIRDocument::findDocumentSymbols( 614 std::vector<lsp::DocumentSymbol> &symbols) { 615 for (Operation &op : parsedIR) 616 findDocumentSymbols(&op, symbols); 617 } 618 619 void MLIRDocument::findDocumentSymbols( 620 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) { 621 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols; 622 623 // Check for the source information of this operation. 624 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) { 625 // If this operation defines a symbol, record it. 626 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) { 627 symbols.emplace_back(symbol.getName(), 628 op->hasTrait<OpTrait::FunctionLike>() 629 ? lsp::SymbolKind::Function 630 : lsp::SymbolKind::Class, 631 getRangeFromLoc(sourceMgr, def->scopeLoc), 632 getRangeFromLoc(sourceMgr, def->loc)); 633 childSymbols = &symbols.back().children; 634 635 } else if (op->hasTrait<OpTrait::SymbolTable>()) { 636 // Otherwise, if this is a symbol table push an anonymous document symbol. 637 symbols.emplace_back("<" + op->getName().getStringRef() + ">", 638 lsp::SymbolKind::Namespace, 639 getRangeFromLoc(sourceMgr, def->scopeLoc), 640 getRangeFromLoc(sourceMgr, def->loc)); 641 childSymbols = &symbols.back().children; 642 } 643 } 644 645 // Recurse into the regions of this operation. 646 if (!op->getNumRegions()) 647 return; 648 for (Region ®ion : op->getRegions()) 649 for (Operation &childOp : region.getOps()) 650 findDocumentSymbols(&childOp, *childSymbols); 651 } 652 653 //===----------------------------------------------------------------------===// 654 // MLIRTextFileChunk 655 //===----------------------------------------------------------------------===// 656 657 namespace { 658 /// This class represents a single chunk of an MLIR text file. 659 struct MLIRTextFileChunk { 660 MLIRTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri, 661 StringRef contents, DialectRegistry ®istry, 662 std::vector<lsp::Diagnostic> &diagnostics) 663 : lineOffset(lineOffset), document(uri, contents, registry, diagnostics) { 664 } 665 666 /// Adjust the line number of the given range to anchor at the beginning of 667 /// the file, instead of the beginning of this chunk. 668 void adjustLocForChunkOffset(lsp::Range &range) { 669 adjustLocForChunkOffset(range.start); 670 adjustLocForChunkOffset(range.end); 671 } 672 /// Adjust the line number of the given position to anchor at the beginning of 673 /// the file, instead of the beginning of this chunk. 674 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } 675 676 /// The line offset of this chunk from the beginning of the file. 677 uint64_t lineOffset; 678 /// The document referred to by this chunk. 679 MLIRDocument document; 680 }; 681 } // namespace 682 683 //===----------------------------------------------------------------------===// 684 // MLIRTextFile 685 //===----------------------------------------------------------------------===// 686 687 namespace { 688 /// This class represents a text file containing one or more MLIR documents. 689 class MLIRTextFile { 690 public: 691 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 692 int64_t version, DialectRegistry ®istry, 693 std::vector<lsp::Diagnostic> &diagnostics); 694 695 /// Return the current version of this text file. 696 int64_t getVersion() const { return version; } 697 698 //===--------------------------------------------------------------------===// 699 // LSP Queries 700 //===--------------------------------------------------------------------===// 701 702 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos, 703 std::vector<lsp::Location> &locations); 704 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos, 705 std::vector<lsp::Location> &references); 706 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 707 lsp::Position hoverPos); 708 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 709 710 private: 711 /// Find the MLIR document that contains the given position, and update the 712 /// position to be anchored at the start of the found chunk instead of the 713 /// beginning of the file. 714 MLIRTextFileChunk &getChunkFor(lsp::Position &pos); 715 716 /// The full string contents of the file. 717 std::string contents; 718 719 /// The version of this file. 720 int64_t version; 721 722 /// The number of lines in the file. 723 int64_t totalNumLines; 724 725 /// The chunks of this file. The order of these chunks is the order in which 726 /// they appear in the text file. 727 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks; 728 }; 729 } // namespace 730 731 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 732 int64_t version, DialectRegistry ®istry, 733 std::vector<lsp::Diagnostic> &diagnostics) 734 : contents(fileContents.str()), version(version), totalNumLines(0) { 735 // Split the file into separate MLIR documents. 736 // TODO: Find a way to share the split file marker with other tools. We don't 737 // want to use `splitAndProcessBuffer` here, but we do want to make sure this 738 // marker doesn't go out of sync. 739 SmallVector<StringRef, 8> subContents; 740 StringRef(contents).split(subContents, "// -----"); 741 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>( 742 /*lineOffset=*/0, uri, subContents.front(), registry, diagnostics)); 743 744 uint64_t lineOffset = subContents.front().count('\n'); 745 for (StringRef docContents : llvm::drop_begin(subContents)) { 746 unsigned currentNumDiags = diagnostics.size(); 747 auto chunk = std::make_unique<MLIRTextFileChunk>( 748 lineOffset, uri, docContents, registry, diagnostics); 749 lineOffset += docContents.count('\n'); 750 751 // Adjust locations used in diagnostics to account for the offset from the 752 // beginning of the file. 753 for (lsp::Diagnostic &diag : 754 llvm::drop_begin(diagnostics, currentNumDiags)) { 755 chunk->adjustLocForChunkOffset(diag.range); 756 757 if (!diag.relatedInformation) 758 continue; 759 for (auto &it : *diag.relatedInformation) 760 if (it.location.uri == uri) 761 chunk->adjustLocForChunkOffset(it.location.range); 762 } 763 chunks.emplace_back(std::move(chunk)); 764 } 765 totalNumLines = lineOffset; 766 } 767 768 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri, 769 lsp::Position defPos, 770 std::vector<lsp::Location> &locations) { 771 MLIRTextFileChunk &chunk = getChunkFor(defPos); 772 chunk.document.getLocationsOf(uri, defPos, locations); 773 774 // Adjust any locations within this file for the offset of this chunk. 775 if (chunk.lineOffset == 0) 776 return; 777 for (lsp::Location &loc : locations) 778 if (loc.uri == uri) 779 chunk.adjustLocForChunkOffset(loc.range); 780 } 781 782 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri, 783 lsp::Position pos, 784 std::vector<lsp::Location> &references) { 785 MLIRTextFileChunk &chunk = getChunkFor(pos); 786 chunk.document.findReferencesOf(uri, pos, references); 787 788 // Adjust any locations within this file for the offset of this chunk. 789 if (chunk.lineOffset == 0) 790 return; 791 for (lsp::Location &loc : references) 792 if (loc.uri == uri) 793 chunk.adjustLocForChunkOffset(loc.range); 794 } 795 796 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri, 797 lsp::Position hoverPos) { 798 MLIRTextFileChunk &chunk = getChunkFor(hoverPos); 799 Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos); 800 801 // Adjust any locations within this file for the offset of this chunk. 802 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) 803 chunk.adjustLocForChunkOffset(*hoverInfo->range); 804 return hoverInfo; 805 } 806 807 void MLIRTextFile::findDocumentSymbols( 808 std::vector<lsp::DocumentSymbol> &symbols) { 809 if (chunks.size() == 1) 810 return chunks.front()->document.findDocumentSymbols(symbols); 811 812 // If there are multiple chunks in this file, we create top-level symbols for 813 // each chunk. 814 for (unsigned i = 0, e = chunks.size(); i < e; ++i) { 815 MLIRTextFileChunk &chunk = *chunks[i]; 816 lsp::Position startPos(chunk.lineOffset); 817 lsp::Position endPos((i == e - 1) ? totalNumLines - 1 818 : chunks[i + 1]->lineOffset); 819 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">", 820 lsp::SymbolKind::Namespace, 821 /*range=*/lsp::Range(startPos, endPos), 822 /*selectionRange=*/lsp::Range(startPos)); 823 chunk.document.findDocumentSymbols(symbol.children); 824 825 // Fixup the locations of document symbols within this chunk. 826 if (i != 0) { 827 SmallVector<lsp::DocumentSymbol *> symbolsToFix; 828 for (lsp::DocumentSymbol &childSymbol : symbol.children) 829 symbolsToFix.push_back(&childSymbol); 830 831 while (!symbolsToFix.empty()) { 832 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); 833 chunk.adjustLocForChunkOffset(symbol->range); 834 chunk.adjustLocForChunkOffset(symbol->selectionRange); 835 836 for (lsp::DocumentSymbol &childSymbol : symbol->children) 837 symbolsToFix.push_back(&childSymbol); 838 } 839 } 840 841 // Push the symbol for this chunk. 842 symbols.emplace_back(std::move(symbol)); 843 } 844 } 845 846 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { 847 if (chunks.size() == 1) 848 return *chunks.front(); 849 850 // Search for the first chunk with a greater line offset, the previous chunk 851 // is the one that contains `pos`. 852 auto it = llvm::upper_bound( 853 chunks, pos, [](const lsp::Position &pos, const auto &chunk) { 854 return static_cast<uint64_t>(pos.line) < chunk->lineOffset; 855 }); 856 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it); 857 pos.line -= chunk.lineOffset; 858 return chunk; 859 } 860 861 //===----------------------------------------------------------------------===// 862 // MLIRServer::Impl 863 //===----------------------------------------------------------------------===// 864 865 struct lsp::MLIRServer::Impl { 866 Impl(DialectRegistry ®istry) : registry(registry) {} 867 868 /// The registry containing dialects that can be recognized in parsed .mlir 869 /// files. 870 DialectRegistry ®istry; 871 872 /// The files held by the server, mapped by their URI file name. 873 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files; 874 }; 875 876 //===----------------------------------------------------------------------===// 877 // MLIRServer 878 //===----------------------------------------------------------------------===// 879 880 lsp::MLIRServer::MLIRServer(DialectRegistry ®istry) 881 : impl(std::make_unique<Impl>(registry)) {} 882 lsp::MLIRServer::~MLIRServer() {} 883 884 void lsp::MLIRServer::addOrUpdateDocument( 885 const URIForFile &uri, StringRef contents, int64_t version, 886 std::vector<Diagnostic> &diagnostics) { 887 impl->files[uri.file()] = std::make_unique<MLIRTextFile>( 888 uri, contents, version, impl->registry, diagnostics); 889 } 890 891 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) { 892 auto it = impl->files.find(uri.file()); 893 if (it == impl->files.end()) 894 return llvm::None; 895 896 int64_t version = it->second->getVersion(); 897 impl->files.erase(it); 898 return version; 899 } 900 901 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, 902 const Position &defPos, 903 std::vector<Location> &locations) { 904 auto fileIt = impl->files.find(uri.file()); 905 if (fileIt != impl->files.end()) 906 fileIt->second->getLocationsOf(uri, defPos, locations); 907 } 908 909 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, 910 const Position &pos, 911 std::vector<Location> &references) { 912 auto fileIt = impl->files.find(uri.file()); 913 if (fileIt != impl->files.end()) 914 fileIt->second->findReferencesOf(uri, pos, references); 915 } 916 917 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri, 918 const Position &hoverPos) { 919 auto fileIt = impl->files.find(uri.file()); 920 if (fileIt != impl->files.end()) 921 return fileIt->second->findHover(uri, hoverPos); 922 return llvm::None; 923 } 924 925 void lsp::MLIRServer::findDocumentSymbols( 926 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) { 927 auto fileIt = impl->files.find(uri.file()); 928 if (fileIt != impl->files.end()) 929 fileIt->second->findDocumentSymbols(symbols); 930 } 931