1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MLIRServer.h" 10 #include "lsp/Logging.h" 11 #include "lsp/Protocol.h" 12 #include "mlir/IR/Operation.h" 13 #include "mlir/Parser.h" 14 #include "mlir/Parser/AsmParserState.h" 15 #include "llvm/Support/SourceMgr.h" 16 17 using namespace mlir; 18 19 /// Returns a language server position for the given source location. 20 static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, llvm::SMLoc loc) { 21 std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc); 22 lsp::Position pos; 23 pos.line = lineAndCol.first - 1; 24 pos.character = lineAndCol.second - 1; 25 return pos; 26 } 27 28 /// Returns a source location from the given language server position. 29 static llvm::SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) { 30 return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), pos.line + 1, 31 pos.character); 32 } 33 34 /// Returns a language server range for the given source range. 35 static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range) { 36 return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)}; 37 } 38 39 /// Returns a language server location from the given source range. 40 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, 41 llvm::SMRange range, 42 const lsp::URIForFile &uri) { 43 return lsp::Location{uri, getRangeFromLoc(mgr, range)}; 44 } 45 46 /// Returns a language server location from the given MLIR file location. 47 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) { 48 llvm::Expected<lsp::URIForFile> sourceURI = 49 lsp::URIForFile::fromFile(loc.getFilename()); 50 if (!sourceURI) { 51 lsp::Logger::error("Failed to create URI for file `{0}`: {1}", 52 loc.getFilename(), 53 llvm::toString(sourceURI.takeError())); 54 return llvm::None; 55 } 56 57 lsp::Position position; 58 position.line = loc.getLine() - 1; 59 position.character = loc.getColumn(); 60 return lsp::Location{*sourceURI, lsp::Range(position)}; 61 } 62 63 /// Returns a language server location from the given MLIR location, or None if 64 /// one couldn't be created. `uri` is an optional additional filter that, when 65 /// present, is used to filter sub locations that do not share the same uri. 66 static Optional<lsp::Location> 67 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc, 68 const lsp::URIForFile *uri = nullptr) { 69 Optional<lsp::Location> location; 70 loc->walk([&](Location nestedLoc) { 71 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 72 if (!fileLoc) 73 return WalkResult::advance(); 74 75 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 76 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) { 77 location = *sourceLoc; 78 llvm::SMLoc loc = sourceMgr.FindLocForLineAndColumn( 79 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn()); 80 81 // Use range of potential identifier starting at location, else length 1 82 // range. 83 location->range.end.character += 1; 84 if (Optional<llvm::SMRange> range = 85 AsmParserState::convertIdLocToRange(loc)) { 86 auto lineCol = sourceMgr.getLineAndColumn(range->End); 87 location->range.end.character = 88 std::max(fileLoc.getColumn() + 1, lineCol.second - 1); 89 } 90 return WalkResult::interrupt(); 91 } 92 return WalkResult::advance(); 93 }); 94 return location; 95 } 96 97 /// Collect all of the locations from the given MLIR location that are not 98 /// contained within the given URI. 99 static void collectLocationsFromLoc(Location loc, 100 std::vector<lsp::Location> &locations, 101 const lsp::URIForFile &uri) { 102 SetVector<Location> visitedLocs; 103 loc->walk([&](Location nestedLoc) { 104 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 105 if (!fileLoc || !visitedLocs.insert(nestedLoc)) 106 return WalkResult::advance(); 107 108 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 109 if (sourceLoc && sourceLoc->uri != uri) 110 locations.push_back(*sourceLoc); 111 return WalkResult::advance(); 112 }); 113 } 114 115 /// Returns true if the given range contains the given source location. Note 116 /// that this has slightly different behavior than SMRange because it is 117 /// inclusive of the end location. 118 static bool contains(llvm::SMRange range, llvm::SMLoc loc) { 119 return range.Start.getPointer() <= loc.getPointer() && 120 loc.getPointer() <= range.End.getPointer(); 121 } 122 123 /// Returns true if the given location is contained by the definition or one of 124 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to 125 /// the range within `def` that the provided `loc` overlapped with. 126 static bool isDefOrUse(const AsmParserState::SMDefinition &def, llvm::SMLoc loc, 127 llvm::SMRange *overlappedRange = nullptr) { 128 // Check the main definition. 129 if (contains(def.loc, loc)) { 130 if (overlappedRange) 131 *overlappedRange = def.loc; 132 return true; 133 } 134 135 // Check the uses. 136 const auto *useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) { 137 return contains(range, loc); 138 }); 139 if (useIt != def.uses.end()) { 140 if (overlappedRange) 141 *overlappedRange = *useIt; 142 return true; 143 } 144 return false; 145 } 146 147 /// Given a location pointing to a result, return the result number it refers 148 /// to or None if it refers to all of the results. 149 static Optional<unsigned> getResultNumberFromLoc(llvm::SMLoc loc) { 150 // Skip all of the identifier characters. 151 auto isIdentifierChar = [](char c) { 152 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' || 153 c == '-'; 154 }; 155 const char *curPtr = loc.getPointer(); 156 while (isIdentifierChar(*curPtr)) 157 ++curPtr; 158 159 // Check to see if this location indexes into the result group, via `#`. If it 160 // doesn't, we can't extract a sub result number. 161 if (*curPtr != '#') 162 return llvm::None; 163 164 // Compute the sub result number from the remaining portion of the string. 165 const char *numberStart = ++curPtr; 166 while (llvm::isDigit(*curPtr)) 167 ++curPtr; 168 StringRef numberStr(numberStart, curPtr - numberStart); 169 unsigned resultNumber = 0; 170 return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>() 171 : resultNumber; 172 } 173 174 /// Given a source location range, return the text covered by the given range. 175 /// If the range is invalid, returns None. 176 static Optional<StringRef> getTextFromRange(llvm::SMRange range) { 177 if (!range.isValid()) 178 return None; 179 const char *startPtr = range.Start.getPointer(); 180 return StringRef(startPtr, range.End.getPointer() - startPtr); 181 } 182 183 /// Given a block, return its position in its parent region. 184 static unsigned getBlockNumber(Block *block) { 185 return std::distance(block->getParent()->begin(), block->getIterator()); 186 } 187 188 /// Given a block and source location, print the source name of the block to the 189 /// given output stream. 190 static void printDefBlockName(raw_ostream &os, Block *block, 191 llvm::SMRange loc = {}) { 192 // Try to extract a name from the source location. 193 Optional<StringRef> text = getTextFromRange(loc); 194 if (text && text->startswith("^")) { 195 os << *text; 196 return; 197 } 198 199 // Otherwise, we don't have a name so print the block number. 200 os << "<Block #" << getBlockNumber(block) << ">"; 201 } 202 static void printDefBlockName(raw_ostream &os, 203 const AsmParserState::BlockDefinition &def) { 204 printDefBlockName(os, def.block, def.definition.loc); 205 } 206 207 /// Convert the given MLIR diagnostic to the LSP form. 208 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, 209 Diagnostic &diag, 210 const lsp::URIForFile &uri) { 211 lsp::Diagnostic lspDiag; 212 lspDiag.source = "mlir"; 213 214 // Note: Right now all of the diagnostics are treated as parser issues, but 215 // some are parser and some are verifier. 216 lspDiag.category = "Parse Error"; 217 218 // Try to grab a file location for this diagnostic. 219 // TODO: For simplicity, we just grab the first one. It may be likely that we 220 // will need a more interesting heuristic here.' 221 Optional<lsp::Location> lspLocation = 222 getLocationFromLoc(sourceMgr, diag.getLocation(), &uri); 223 if (lspLocation) 224 lspDiag.range = lspLocation->range; 225 226 // Convert the severity for the diagnostic. 227 switch (diag.getSeverity()) { 228 case DiagnosticSeverity::Note: 229 llvm_unreachable("expected notes to be handled separately"); 230 case DiagnosticSeverity::Warning: 231 lspDiag.severity = lsp::DiagnosticSeverity::Warning; 232 break; 233 case DiagnosticSeverity::Error: 234 lspDiag.severity = lsp::DiagnosticSeverity::Error; 235 break; 236 case DiagnosticSeverity::Remark: 237 lspDiag.severity = lsp::DiagnosticSeverity::Information; 238 break; 239 } 240 lspDiag.message = diag.str(); 241 242 // Attach any notes to the main diagnostic as related information. 243 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags; 244 for (Diagnostic ¬e : diag.getNotes()) { 245 lsp::Location noteLoc; 246 if (Optional<lsp::Location> loc = 247 getLocationFromLoc(sourceMgr, note.getLocation())) 248 noteLoc = *loc; 249 else 250 noteLoc.uri = uri; 251 relatedDiags.emplace_back(noteLoc, note.str()); 252 } 253 if (!relatedDiags.empty()) 254 lspDiag.relatedInformation = std::move(relatedDiags); 255 256 return lspDiag; 257 } 258 259 //===----------------------------------------------------------------------===// 260 // MLIRDocument 261 //===----------------------------------------------------------------------===// 262 263 namespace { 264 /// This class represents all of the information pertaining to a specific MLIR 265 /// document. 266 struct MLIRDocument { 267 MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, 268 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics); 269 MLIRDocument(const MLIRDocument &) = delete; 270 MLIRDocument &operator=(const MLIRDocument &) = delete; 271 272 //===--------------------------------------------------------------------===// 273 // Definitions and References 274 //===--------------------------------------------------------------------===// 275 276 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, 277 std::vector<lsp::Location> &locations); 278 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, 279 std::vector<lsp::Location> &references); 280 281 //===--------------------------------------------------------------------===// 282 // Hover 283 //===--------------------------------------------------------------------===// 284 285 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 286 const lsp::Position &hoverPos); 287 Optional<lsp::Hover> 288 buildHoverForOperation(llvm::SMRange hoverRange, 289 const AsmParserState::OperationDefinition &op); 290 lsp::Hover buildHoverForOperationResult(llvm::SMRange hoverRange, 291 Operation *op, unsigned resultStart, 292 unsigned resultEnd, 293 llvm::SMLoc posLoc); 294 lsp::Hover buildHoverForBlock(llvm::SMRange hoverRange, 295 const AsmParserState::BlockDefinition &block); 296 lsp::Hover 297 buildHoverForBlockArgument(llvm::SMRange hoverRange, BlockArgument arg, 298 const AsmParserState::BlockDefinition &block); 299 300 //===--------------------------------------------------------------------===// 301 // Document Symbols 302 //===--------------------------------------------------------------------===// 303 304 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 305 void findDocumentSymbols(Operation *op, 306 std::vector<lsp::DocumentSymbol> &symbols); 307 308 //===--------------------------------------------------------------------===// 309 // Fields 310 //===--------------------------------------------------------------------===// 311 312 /// The high level parser state used to find definitions and references within 313 /// the source file. 314 AsmParserState asmState; 315 316 /// The container for the IR parsed from the input file. 317 Block parsedIR; 318 319 /// The source manager containing the contents of the input file. 320 llvm::SourceMgr sourceMgr; 321 }; 322 } // namespace 323 324 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, 325 StringRef contents, 326 std::vector<lsp::Diagnostic> &diagnostics) { 327 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) { 328 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri)); 329 }); 330 331 // Try to parsed the given IR string. 332 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); 333 if (!memBuffer) { 334 lsp::Logger::error("Failed to create memory buffer for file", uri.file()); 335 return; 336 } 337 338 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), llvm::SMLoc()); 339 if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr, 340 &asmState))) { 341 // If parsing failed, clear out any of the current state. 342 parsedIR.clear(); 343 asmState = AsmParserState(); 344 return; 345 } 346 } 347 348 //===----------------------------------------------------------------------===// 349 // MLIRDocument: Definitions and References 350 //===----------------------------------------------------------------------===// 351 352 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri, 353 const lsp::Position &defPos, 354 std::vector<lsp::Location> &locations) { 355 llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, defPos); 356 357 // Functor used to check if an SM definition contains the position. 358 auto containsPosition = [&](const AsmParserState::SMDefinition &def) { 359 if (!isDefOrUse(def, posLoc)) 360 return false; 361 locations.push_back(getLocationFromLoc(sourceMgr, def.loc, uri)); 362 return true; 363 }; 364 365 // Check all definitions related to operations. 366 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 367 if (contains(op.loc, posLoc)) 368 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 369 for (const auto &result : op.resultGroups) 370 if (containsPosition(result.definition)) 371 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 372 for (const auto &symUse : op.symbolUses) { 373 if (contains(symUse, posLoc)) { 374 locations.push_back(getLocationFromLoc(sourceMgr, op.loc, uri)); 375 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 376 } 377 } 378 } 379 380 // Check all definitions related to blocks. 381 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 382 if (containsPosition(block.definition)) 383 return; 384 for (const AsmParserState::SMDefinition &arg : block.arguments) 385 if (containsPosition(arg)) 386 return; 387 } 388 } 389 390 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri, 391 const lsp::Position &pos, 392 std::vector<lsp::Location> &references) { 393 // Functor used to append all of the definitions/uses of the given SM 394 // definition to the reference list. 395 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) { 396 references.push_back(getLocationFromLoc(sourceMgr, def.loc, uri)); 397 for (const llvm::SMRange &use : def.uses) 398 references.push_back(getLocationFromLoc(sourceMgr, use, uri)); 399 }; 400 401 llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, pos); 402 403 // Check all definitions related to operations. 404 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 405 if (contains(op.loc, posLoc)) { 406 for (const auto &result : op.resultGroups) 407 appendSMDef(result.definition); 408 for (const auto &symUse : op.symbolUses) 409 if (contains(symUse, posLoc)) 410 references.push_back(getLocationFromLoc(sourceMgr, symUse, uri)); 411 return; 412 } 413 for (const auto &result : op.resultGroups) 414 if (isDefOrUse(result.definition, posLoc)) 415 return appendSMDef(result.definition); 416 for (const auto &symUse : op.symbolUses) { 417 if (!contains(symUse, posLoc)) 418 continue; 419 for (const auto &symUse : op.symbolUses) 420 references.push_back(getLocationFromLoc(sourceMgr, symUse, uri)); 421 return; 422 } 423 } 424 425 // Check all definitions related to blocks. 426 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 427 if (isDefOrUse(block.definition, posLoc)) 428 return appendSMDef(block.definition); 429 430 for (const AsmParserState::SMDefinition &arg : block.arguments) 431 if (isDefOrUse(arg, posLoc)) 432 return appendSMDef(arg); 433 } 434 } 435 436 //===----------------------------------------------------------------------===// 437 // MLIRDocument: Hover 438 //===----------------------------------------------------------------------===// 439 440 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri, 441 const lsp::Position &hoverPos) { 442 llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, hoverPos); 443 llvm::SMRange hoverRange; 444 445 // Check for Hovers on operations and results. 446 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 447 // Check if the position points at this operation. 448 if (contains(op.loc, posLoc)) 449 return buildHoverForOperation(op.loc, op); 450 451 // Check if the position points at the symbol name. 452 for (auto &use : op.symbolUses) 453 if (contains(use, posLoc)) 454 return buildHoverForOperation(use, op); 455 456 // Check if the position points at a result group. 457 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) { 458 const auto &result = op.resultGroups[i]; 459 if (!isDefOrUse(result.definition, posLoc, &hoverRange)) 460 continue; 461 462 // Get the range of results covered by the over position. 463 unsigned resultStart = result.startIndex; 464 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults() 465 : op.resultGroups[i + 1].startIndex; 466 return buildHoverForOperationResult(hoverRange, op.op, resultStart, 467 resultEnd, posLoc); 468 } 469 } 470 471 // Check to see if the hover is over a block argument. 472 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 473 if (isDefOrUse(block.definition, posLoc, &hoverRange)) 474 return buildHoverForBlock(hoverRange, block); 475 476 for (const auto &arg : llvm::enumerate(block.arguments)) { 477 if (!isDefOrUse(arg.value(), posLoc, &hoverRange)) 478 continue; 479 480 return buildHoverForBlockArgument( 481 hoverRange, block.block->getArgument(arg.index()), block); 482 } 483 } 484 return llvm::None; 485 } 486 487 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation( 488 llvm::SMRange hoverRange, const AsmParserState::OperationDefinition &op) { 489 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 490 llvm::raw_string_ostream os(hover.contents.value); 491 492 // Add the operation name to the hover. 493 os << "\"" << op.op->getName() << "\""; 494 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op)) 495 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << ""; 496 os << "\n\n"; 497 498 os << "Generic Form:\n\n```mlir\n"; 499 500 // Temporary drop the regions of this operation so that they don't get 501 // printed in the output. This helps keeps the size of the output hover 502 // small. 503 SmallVector<std::unique_ptr<Region>> regions; 504 for (Region ®ion : op.op->getRegions()) { 505 regions.emplace_back(std::make_unique<Region>()); 506 regions.back()->takeBody(region); 507 } 508 509 op.op->print( 510 os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); 511 os << "\n```\n"; 512 513 // Move the regions back to the current operation. 514 for (Region ®ion : op.op->getRegions()) 515 region.takeBody(*regions.back()); 516 517 return hover; 518 } 519 520 lsp::Hover MLIRDocument::buildHoverForOperationResult(llvm::SMRange hoverRange, 521 Operation *op, 522 unsigned resultStart, 523 unsigned resultEnd, 524 llvm::SMLoc posLoc) { 525 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 526 llvm::raw_string_ostream os(hover.contents.value); 527 528 // Add the parent operation name to the hover. 529 os << "Operation: \"" << op->getName() << "\"\n\n"; 530 531 // Check to see if the location points to a specific result within the 532 // group. 533 if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) { 534 if ((resultStart + *resultNumber) < resultEnd) { 535 resultStart += *resultNumber; 536 resultEnd = resultStart + 1; 537 } 538 } 539 540 // Add the range of results and their types to the hover info. 541 if ((resultStart + 1) == resultEnd) { 542 os << "Result #" << resultStart << "\n\n" 543 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n"; 544 } else { 545 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n" 546 << "Types: "; 547 llvm::interleaveComma( 548 op->getResults().slice(resultStart, resultEnd), os, 549 [&](Value result) { os << "`" << result.getType() << "`"; }); 550 } 551 552 return hover; 553 } 554 555 lsp::Hover 556 MLIRDocument::buildHoverForBlock(llvm::SMRange hoverRange, 557 const AsmParserState::BlockDefinition &block) { 558 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 559 llvm::raw_string_ostream os(hover.contents.value); 560 561 // Print the given block to the hover output stream. 562 auto printBlockToHover = [&](Block *newBlock) { 563 if (const auto *def = asmState.getBlockDef(newBlock)) 564 printDefBlockName(os, *def); 565 else 566 printDefBlockName(os, newBlock); 567 }; 568 569 // Display the parent operation, block number, predecessors, and successors. 570 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 571 << "Block #" << getBlockNumber(block.block) << "\n\n"; 572 if (!block.block->hasNoPredecessors()) { 573 os << "Predecessors: "; 574 llvm::interleaveComma(block.block->getPredecessors(), os, 575 printBlockToHover); 576 os << "\n\n"; 577 } 578 if (!block.block->hasNoSuccessors()) { 579 os << "Successors: "; 580 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover); 581 os << "\n\n"; 582 } 583 584 return hover; 585 } 586 587 lsp::Hover MLIRDocument::buildHoverForBlockArgument( 588 llvm::SMRange hoverRange, BlockArgument arg, 589 const AsmParserState::BlockDefinition &block) { 590 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 591 llvm::raw_string_ostream os(hover.contents.value); 592 593 // Display the parent operation, block, the argument number, and the type. 594 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 595 << "Block: "; 596 printDefBlockName(os, block); 597 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n" 598 << "Type: `" << arg.getType() << "`\n\n"; 599 600 return hover; 601 } 602 603 //===----------------------------------------------------------------------===// 604 // MLIRDocument: Document Symbols 605 //===----------------------------------------------------------------------===// 606 607 void MLIRDocument::findDocumentSymbols( 608 std::vector<lsp::DocumentSymbol> &symbols) { 609 for (Operation &op : parsedIR) 610 findDocumentSymbols(&op, symbols); 611 } 612 613 void MLIRDocument::findDocumentSymbols( 614 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) { 615 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols; 616 617 // Check for the source information of this operation. 618 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) { 619 // If this operation defines a symbol, record it. 620 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) { 621 symbols.emplace_back(symbol.getName(), 622 op->hasTrait<OpTrait::FunctionLike>() 623 ? lsp::SymbolKind::Function 624 : lsp::SymbolKind::Class, 625 getRangeFromLoc(sourceMgr, def->scopeLoc), 626 getRangeFromLoc(sourceMgr, def->loc)); 627 childSymbols = &symbols.back().children; 628 629 } else if (op->hasTrait<OpTrait::SymbolTable>()) { 630 // Otherwise, if this is a symbol table push an anonymous document symbol. 631 symbols.emplace_back("<" + op->getName().getStringRef() + ">", 632 lsp::SymbolKind::Namespace, 633 getRangeFromLoc(sourceMgr, def->scopeLoc), 634 getRangeFromLoc(sourceMgr, def->loc)); 635 childSymbols = &symbols.back().children; 636 } 637 } 638 639 // Recurse into the regions of this operation. 640 if (!op->getNumRegions()) 641 return; 642 for (Region ®ion : op->getRegions()) 643 for (Operation &childOp : region.getOps()) 644 findDocumentSymbols(&childOp, *childSymbols); 645 } 646 647 //===----------------------------------------------------------------------===// 648 // MLIRTextFileChunk 649 //===----------------------------------------------------------------------===// 650 651 namespace { 652 /// This class represents a single chunk of an MLIR text file. 653 struct MLIRTextFileChunk { 654 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset, 655 const lsp::URIForFile &uri, StringRef contents, 656 std::vector<lsp::Diagnostic> &diagnostics) 657 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {} 658 659 /// Adjust the line number of the given range to anchor at the beginning of 660 /// the file, instead of the beginning of this chunk. 661 void adjustLocForChunkOffset(lsp::Range &range) { 662 adjustLocForChunkOffset(range.start); 663 adjustLocForChunkOffset(range.end); 664 } 665 /// Adjust the line number of the given position to anchor at the beginning of 666 /// the file, instead of the beginning of this chunk. 667 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } 668 669 /// The line offset of this chunk from the beginning of the file. 670 uint64_t lineOffset; 671 /// The document referred to by this chunk. 672 MLIRDocument document; 673 }; 674 } // namespace 675 676 //===----------------------------------------------------------------------===// 677 // MLIRTextFile 678 //===----------------------------------------------------------------------===// 679 680 namespace { 681 /// This class represents a text file containing one or more MLIR documents. 682 class MLIRTextFile { 683 public: 684 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 685 int64_t version, DialectRegistry ®istry, 686 std::vector<lsp::Diagnostic> &diagnostics); 687 688 /// Return the current version of this text file. 689 int64_t getVersion() const { return version; } 690 691 //===--------------------------------------------------------------------===// 692 // LSP Queries 693 //===--------------------------------------------------------------------===// 694 695 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos, 696 std::vector<lsp::Location> &locations); 697 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos, 698 std::vector<lsp::Location> &references); 699 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 700 lsp::Position hoverPos); 701 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 702 703 private: 704 /// Find the MLIR document that contains the given position, and update the 705 /// position to be anchored at the start of the found chunk instead of the 706 /// beginning of the file. 707 MLIRTextFileChunk &getChunkFor(lsp::Position &pos); 708 709 /// The context used to hold the state contained by the parsed document. 710 MLIRContext context; 711 712 /// The full string contents of the file. 713 std::string contents; 714 715 /// The version of this file. 716 int64_t version; 717 718 /// The number of lines in the file. 719 int64_t totalNumLines; 720 721 /// The chunks of this file. The order of these chunks is the order in which 722 /// they appear in the text file. 723 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks; 724 }; 725 } // namespace 726 727 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 728 int64_t version, DialectRegistry ®istry, 729 std::vector<lsp::Diagnostic> &diagnostics) 730 : context(registry, MLIRContext::Threading::DISABLED), 731 contents(fileContents.str()), version(version), totalNumLines(0) { 732 context.allowUnregisteredDialects(); 733 734 // Split the file into separate MLIR documents. 735 // TODO: Find a way to share the split file marker with other tools. We don't 736 // want to use `splitAndProcessBuffer` here, but we do want to make sure this 737 // marker doesn't go out of sync. 738 SmallVector<StringRef, 8> subContents; 739 StringRef(contents).split(subContents, "// -----"); 740 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>( 741 context, /*lineOffset=*/0, uri, subContents.front(), diagnostics)); 742 743 uint64_t lineOffset = subContents.front().count('\n'); 744 for (StringRef docContents : llvm::drop_begin(subContents)) { 745 unsigned currentNumDiags = diagnostics.size(); 746 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri, 747 docContents, diagnostics); 748 lineOffset += docContents.count('\n'); 749 750 // Adjust locations used in diagnostics to account for the offset from the 751 // beginning of the file. 752 for (lsp::Diagnostic &diag : 753 llvm::drop_begin(diagnostics, currentNumDiags)) { 754 chunk->adjustLocForChunkOffset(diag.range); 755 756 if (!diag.relatedInformation) 757 continue; 758 for (auto &it : *diag.relatedInformation) 759 if (it.location.uri == uri) 760 chunk->adjustLocForChunkOffset(it.location.range); 761 } 762 chunks.emplace_back(std::move(chunk)); 763 } 764 totalNumLines = lineOffset; 765 } 766 767 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri, 768 lsp::Position defPos, 769 std::vector<lsp::Location> &locations) { 770 MLIRTextFileChunk &chunk = getChunkFor(defPos); 771 chunk.document.getLocationsOf(uri, defPos, locations); 772 773 // Adjust any locations within this file for the offset of this chunk. 774 if (chunk.lineOffset == 0) 775 return; 776 for (lsp::Location &loc : locations) 777 if (loc.uri == uri) 778 chunk.adjustLocForChunkOffset(loc.range); 779 } 780 781 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri, 782 lsp::Position pos, 783 std::vector<lsp::Location> &references) { 784 MLIRTextFileChunk &chunk = getChunkFor(pos); 785 chunk.document.findReferencesOf(uri, pos, references); 786 787 // Adjust any locations within this file for the offset of this chunk. 788 if (chunk.lineOffset == 0) 789 return; 790 for (lsp::Location &loc : references) 791 if (loc.uri == uri) 792 chunk.adjustLocForChunkOffset(loc.range); 793 } 794 795 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri, 796 lsp::Position hoverPos) { 797 MLIRTextFileChunk &chunk = getChunkFor(hoverPos); 798 Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos); 799 800 // Adjust any locations within this file for the offset of this chunk. 801 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) 802 chunk.adjustLocForChunkOffset(*hoverInfo->range); 803 return hoverInfo; 804 } 805 806 void MLIRTextFile::findDocumentSymbols( 807 std::vector<lsp::DocumentSymbol> &symbols) { 808 if (chunks.size() == 1) 809 return chunks.front()->document.findDocumentSymbols(symbols); 810 811 // If there are multiple chunks in this file, we create top-level symbols for 812 // each chunk. 813 for (unsigned i = 0, e = chunks.size(); i < e; ++i) { 814 MLIRTextFileChunk &chunk = *chunks[i]; 815 lsp::Position startPos(chunk.lineOffset); 816 lsp::Position endPos((i == e - 1) ? totalNumLines - 1 817 : chunks[i + 1]->lineOffset); 818 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">", 819 lsp::SymbolKind::Namespace, 820 /*range=*/lsp::Range(startPos, endPos), 821 /*selectionRange=*/lsp::Range(startPos)); 822 chunk.document.findDocumentSymbols(symbol.children); 823 824 // Fixup the locations of document symbols within this chunk. 825 if (i != 0) { 826 SmallVector<lsp::DocumentSymbol *> symbolsToFix; 827 for (lsp::DocumentSymbol &childSymbol : symbol.children) 828 symbolsToFix.push_back(&childSymbol); 829 830 while (!symbolsToFix.empty()) { 831 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); 832 chunk.adjustLocForChunkOffset(symbol->range); 833 chunk.adjustLocForChunkOffset(symbol->selectionRange); 834 835 for (lsp::DocumentSymbol &childSymbol : symbol->children) 836 symbolsToFix.push_back(&childSymbol); 837 } 838 } 839 840 // Push the symbol for this chunk. 841 symbols.emplace_back(std::move(symbol)); 842 } 843 } 844 845 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { 846 if (chunks.size() == 1) 847 return *chunks.front(); 848 849 // Search for the first chunk with a greater line offset, the previous chunk 850 // is the one that contains `pos`. 851 auto it = llvm::upper_bound( 852 chunks, pos, [](const lsp::Position &pos, const auto &chunk) { 853 return static_cast<uint64_t>(pos.line) < chunk->lineOffset; 854 }); 855 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it); 856 pos.line -= chunk.lineOffset; 857 return chunk; 858 } 859 860 //===----------------------------------------------------------------------===// 861 // MLIRServer::Impl 862 //===----------------------------------------------------------------------===// 863 864 struct lsp::MLIRServer::Impl { 865 Impl(DialectRegistry ®istry) : registry(registry) {} 866 867 /// The registry containing dialects that can be recognized in parsed .mlir 868 /// files. 869 DialectRegistry ®istry; 870 871 /// The files held by the server, mapped by their URI file name. 872 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files; 873 }; 874 875 //===----------------------------------------------------------------------===// 876 // MLIRServer 877 //===----------------------------------------------------------------------===// 878 879 lsp::MLIRServer::MLIRServer(DialectRegistry ®istry) 880 : impl(std::make_unique<Impl>(registry)) {} 881 lsp::MLIRServer::~MLIRServer() = default; 882 883 void lsp::MLIRServer::addOrUpdateDocument( 884 const URIForFile &uri, StringRef contents, int64_t version, 885 std::vector<Diagnostic> &diagnostics) { 886 impl->files[uri.file()] = std::make_unique<MLIRTextFile>( 887 uri, contents, version, impl->registry, diagnostics); 888 } 889 890 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) { 891 auto it = impl->files.find(uri.file()); 892 if (it == impl->files.end()) 893 return llvm::None; 894 895 int64_t version = it->second->getVersion(); 896 impl->files.erase(it); 897 return version; 898 } 899 900 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, 901 const Position &defPos, 902 std::vector<Location> &locations) { 903 auto fileIt = impl->files.find(uri.file()); 904 if (fileIt != impl->files.end()) 905 fileIt->second->getLocationsOf(uri, defPos, locations); 906 } 907 908 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, 909 const Position &pos, 910 std::vector<Location> &references) { 911 auto fileIt = impl->files.find(uri.file()); 912 if (fileIt != impl->files.end()) 913 fileIt->second->findReferencesOf(uri, pos, references); 914 } 915 916 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri, 917 const Position &hoverPos) { 918 auto fileIt = impl->files.find(uri.file()); 919 if (fileIt != impl->files.end()) 920 return fileIt->second->findHover(uri, hoverPos); 921 return llvm::None; 922 } 923 924 void lsp::MLIRServer::findDocumentSymbols( 925 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) { 926 auto fileIt = impl->files.find(uri.file()); 927 if (fileIt != impl->files.end()) 928 fileIt->second->findDocumentSymbols(symbols); 929 } 930