1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MLIRServer.h" 10 #include "lsp/Logging.h" 11 #include "lsp/Protocol.h" 12 #include "mlir/IR/FunctionInterfaces.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/Parser/AsmParserState.h" 15 #include "mlir/Parser/Parser.h" 16 #include "llvm/Support/SourceMgr.h" 17 18 using namespace mlir; 19 20 /// Returns a language server position for the given source location. 21 static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, SMLoc loc) { 22 std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc); 23 lsp::Position pos; 24 pos.line = lineAndCol.first - 1; 25 pos.character = lineAndCol.second - 1; 26 return pos; 27 } 28 29 /// Returns a source location from the given language server position. 30 static SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) { 31 return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), pos.line + 1, 32 pos.character); 33 } 34 35 /// Returns a language server range for the given source range. 36 static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, SMRange range) { 37 return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)}; 38 } 39 40 /// Returns a language server location from the given source range. 41 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, 42 SMRange range, 43 const lsp::URIForFile &uri) { 44 return lsp::Location{uri, getRangeFromLoc(mgr, range)}; 45 } 46 47 /// Returns a language server location from the given MLIR file location. 48 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) { 49 llvm::Expected<lsp::URIForFile> sourceURI = 50 lsp::URIForFile::fromFile(loc.getFilename()); 51 if (!sourceURI) { 52 lsp::Logger::error("Failed to create URI for file `{0}`: {1}", 53 loc.getFilename(), 54 llvm::toString(sourceURI.takeError())); 55 return llvm::None; 56 } 57 58 lsp::Position position; 59 position.line = loc.getLine() - 1; 60 position.character = loc.getColumn(); 61 return lsp::Location{*sourceURI, lsp::Range(position)}; 62 } 63 64 /// Returns a language server location from the given MLIR location, or None if 65 /// one couldn't be created. `uri` is an optional additional filter that, when 66 /// present, is used to filter sub locations that do not share the same uri. 67 static Optional<lsp::Location> 68 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc, 69 const lsp::URIForFile *uri = nullptr) { 70 Optional<lsp::Location> location; 71 loc->walk([&](Location nestedLoc) { 72 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 73 if (!fileLoc) 74 return WalkResult::advance(); 75 76 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 77 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) { 78 location = *sourceLoc; 79 SMLoc loc = sourceMgr.FindLocForLineAndColumn( 80 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn()); 81 82 // Use range of potential identifier starting at location, else length 1 83 // range. 84 location->range.end.character += 1; 85 if (Optional<SMRange> range = 86 AsmParserState::convertIdLocToRange(loc)) { 87 auto lineCol = sourceMgr.getLineAndColumn(range->End); 88 location->range.end.character = 89 std::max(fileLoc.getColumn() + 1, lineCol.second - 1); 90 } 91 return WalkResult::interrupt(); 92 } 93 return WalkResult::advance(); 94 }); 95 return location; 96 } 97 98 /// Collect all of the locations from the given MLIR location that are not 99 /// contained within the given URI. 100 static void collectLocationsFromLoc(Location loc, 101 std::vector<lsp::Location> &locations, 102 const lsp::URIForFile &uri) { 103 SetVector<Location> visitedLocs; 104 loc->walk([&](Location nestedLoc) { 105 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 106 if (!fileLoc || !visitedLocs.insert(nestedLoc)) 107 return WalkResult::advance(); 108 109 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 110 if (sourceLoc && sourceLoc->uri != uri) 111 locations.push_back(*sourceLoc); 112 return WalkResult::advance(); 113 }); 114 } 115 116 /// Returns true if the given range contains the given source location. Note 117 /// that this has slightly different behavior than SMRange because it is 118 /// inclusive of the end location. 119 static bool contains(SMRange range, SMLoc loc) { 120 return range.Start.getPointer() <= loc.getPointer() && 121 loc.getPointer() <= range.End.getPointer(); 122 } 123 124 /// Returns true if the given location is contained by the definition or one of 125 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to 126 /// the range within `def` that the provided `loc` overlapped with. 127 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, 128 SMRange *overlappedRange = nullptr) { 129 // Check the main definition. 130 if (contains(def.loc, loc)) { 131 if (overlappedRange) 132 *overlappedRange = def.loc; 133 return true; 134 } 135 136 // Check the uses. 137 const auto *useIt = llvm::find_if(def.uses, [&](const SMRange &range) { 138 return contains(range, loc); 139 }); 140 if (useIt != def.uses.end()) { 141 if (overlappedRange) 142 *overlappedRange = *useIt; 143 return true; 144 } 145 return false; 146 } 147 148 /// Given a location pointing to a result, return the result number it refers 149 /// to or None if it refers to all of the results. 150 static Optional<unsigned> getResultNumberFromLoc(SMLoc loc) { 151 // Skip all of the identifier characters. 152 auto isIdentifierChar = [](char c) { 153 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' || 154 c == '-'; 155 }; 156 const char *curPtr = loc.getPointer(); 157 while (isIdentifierChar(*curPtr)) 158 ++curPtr; 159 160 // Check to see if this location indexes into the result group, via `#`. If it 161 // doesn't, we can't extract a sub result number. 162 if (*curPtr != '#') 163 return llvm::None; 164 165 // Compute the sub result number from the remaining portion of the string. 166 const char *numberStart = ++curPtr; 167 while (llvm::isDigit(*curPtr)) 168 ++curPtr; 169 StringRef numberStr(numberStart, curPtr - numberStart); 170 unsigned resultNumber = 0; 171 return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>() 172 : resultNumber; 173 } 174 175 /// Given a source location range, return the text covered by the given range. 176 /// If the range is invalid, returns None. 177 static Optional<StringRef> getTextFromRange(SMRange range) { 178 if (!range.isValid()) 179 return None; 180 const char *startPtr = range.Start.getPointer(); 181 return StringRef(startPtr, range.End.getPointer() - startPtr); 182 } 183 184 /// Given a block, return its position in its parent region. 185 static unsigned getBlockNumber(Block *block) { 186 return std::distance(block->getParent()->begin(), block->getIterator()); 187 } 188 189 /// Given a block and source location, print the source name of the block to the 190 /// given output stream. 191 static void printDefBlockName(raw_ostream &os, Block *block, 192 SMRange loc = {}) { 193 // Try to extract a name from the source location. 194 Optional<StringRef> text = getTextFromRange(loc); 195 if (text && text->startswith("^")) { 196 os << *text; 197 return; 198 } 199 200 // Otherwise, we don't have a name so print the block number. 201 os << "<Block #" << getBlockNumber(block) << ">"; 202 } 203 static void printDefBlockName(raw_ostream &os, 204 const AsmParserState::BlockDefinition &def) { 205 printDefBlockName(os, def.block, def.definition.loc); 206 } 207 208 /// Convert the given MLIR diagnostic to the LSP form. 209 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, 210 Diagnostic &diag, 211 const lsp::URIForFile &uri) { 212 lsp::Diagnostic lspDiag; 213 lspDiag.source = "mlir"; 214 215 // Note: Right now all of the diagnostics are treated as parser issues, but 216 // some are parser and some are verifier. 217 lspDiag.category = "Parse Error"; 218 219 // Try to grab a file location for this diagnostic. 220 // TODO: For simplicity, we just grab the first one. It may be likely that we 221 // will need a more interesting heuristic here.' 222 Optional<lsp::Location> lspLocation = 223 getLocationFromLoc(sourceMgr, diag.getLocation(), &uri); 224 if (lspLocation) 225 lspDiag.range = lspLocation->range; 226 227 // Convert the severity for the diagnostic. 228 switch (diag.getSeverity()) { 229 case DiagnosticSeverity::Note: 230 llvm_unreachable("expected notes to be handled separately"); 231 case DiagnosticSeverity::Warning: 232 lspDiag.severity = lsp::DiagnosticSeverity::Warning; 233 break; 234 case DiagnosticSeverity::Error: 235 lspDiag.severity = lsp::DiagnosticSeverity::Error; 236 break; 237 case DiagnosticSeverity::Remark: 238 lspDiag.severity = lsp::DiagnosticSeverity::Information; 239 break; 240 } 241 lspDiag.message = diag.str(); 242 243 // Attach any notes to the main diagnostic as related information. 244 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags; 245 for (Diagnostic ¬e : diag.getNotes()) { 246 lsp::Location noteLoc; 247 if (Optional<lsp::Location> loc = 248 getLocationFromLoc(sourceMgr, note.getLocation())) 249 noteLoc = *loc; 250 else 251 noteLoc.uri = uri; 252 relatedDiags.emplace_back(noteLoc, note.str()); 253 } 254 if (!relatedDiags.empty()) 255 lspDiag.relatedInformation = std::move(relatedDiags); 256 257 return lspDiag; 258 } 259 260 //===----------------------------------------------------------------------===// 261 // MLIRDocument 262 //===----------------------------------------------------------------------===// 263 264 namespace { 265 /// This class represents all of the information pertaining to a specific MLIR 266 /// document. 267 struct MLIRDocument { 268 MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, 269 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics); 270 MLIRDocument(const MLIRDocument &) = delete; 271 MLIRDocument &operator=(const MLIRDocument &) = delete; 272 273 //===--------------------------------------------------------------------===// 274 // Definitions and References 275 //===--------------------------------------------------------------------===// 276 277 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, 278 std::vector<lsp::Location> &locations); 279 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, 280 std::vector<lsp::Location> &references); 281 282 //===--------------------------------------------------------------------===// 283 // Hover 284 //===--------------------------------------------------------------------===// 285 286 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 287 const lsp::Position &hoverPos); 288 Optional<lsp::Hover> 289 buildHoverForOperation(SMRange hoverRange, 290 const AsmParserState::OperationDefinition &op); 291 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, 292 Operation *op, unsigned resultStart, 293 unsigned resultEnd, 294 SMLoc posLoc); 295 lsp::Hover buildHoverForBlock(SMRange hoverRange, 296 const AsmParserState::BlockDefinition &block); 297 lsp::Hover 298 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg, 299 const AsmParserState::BlockDefinition &block); 300 301 //===--------------------------------------------------------------------===// 302 // Document Symbols 303 //===--------------------------------------------------------------------===// 304 305 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 306 void findDocumentSymbols(Operation *op, 307 std::vector<lsp::DocumentSymbol> &symbols); 308 309 //===--------------------------------------------------------------------===// 310 // Fields 311 //===--------------------------------------------------------------------===// 312 313 /// The high level parser state used to find definitions and references within 314 /// the source file. 315 AsmParserState asmState; 316 317 /// The container for the IR parsed from the input file. 318 Block parsedIR; 319 320 /// The source manager containing the contents of the input file. 321 llvm::SourceMgr sourceMgr; 322 }; 323 } // namespace 324 325 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, 326 StringRef contents, 327 std::vector<lsp::Diagnostic> &diagnostics) { 328 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) { 329 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri)); 330 }); 331 332 // Try to parsed the given IR string. 333 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); 334 if (!memBuffer) { 335 lsp::Logger::error("Failed to create memory buffer for file", uri.file()); 336 return; 337 } 338 339 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 340 if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr, 341 &asmState))) { 342 // If parsing failed, clear out any of the current state. 343 parsedIR.clear(); 344 asmState = AsmParserState(); 345 return; 346 } 347 } 348 349 //===----------------------------------------------------------------------===// 350 // MLIRDocument: Definitions and References 351 //===----------------------------------------------------------------------===// 352 353 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri, 354 const lsp::Position &defPos, 355 std::vector<lsp::Location> &locations) { 356 SMLoc posLoc = getPosFromLoc(sourceMgr, defPos); 357 358 // Functor used to check if an SM definition contains the position. 359 auto containsPosition = [&](const AsmParserState::SMDefinition &def) { 360 if (!isDefOrUse(def, posLoc)) 361 return false; 362 locations.push_back(getLocationFromLoc(sourceMgr, def.loc, uri)); 363 return true; 364 }; 365 366 // Check all definitions related to operations. 367 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 368 if (contains(op.loc, posLoc)) 369 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 370 for (const auto &result : op.resultGroups) 371 if (containsPosition(result.definition)) 372 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 373 for (const auto &symUse : op.symbolUses) { 374 if (contains(symUse, posLoc)) { 375 locations.push_back(getLocationFromLoc(sourceMgr, op.loc, uri)); 376 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 377 } 378 } 379 } 380 381 // Check all definitions related to blocks. 382 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 383 if (containsPosition(block.definition)) 384 return; 385 for (const AsmParserState::SMDefinition &arg : block.arguments) 386 if (containsPosition(arg)) 387 return; 388 } 389 } 390 391 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri, 392 const lsp::Position &pos, 393 std::vector<lsp::Location> &references) { 394 // Functor used to append all of the definitions/uses of the given SM 395 // definition to the reference list. 396 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) { 397 references.push_back(getLocationFromLoc(sourceMgr, def.loc, uri)); 398 for (const SMRange &use : def.uses) 399 references.push_back(getLocationFromLoc(sourceMgr, use, uri)); 400 }; 401 402 SMLoc posLoc = getPosFromLoc(sourceMgr, pos); 403 404 // Check all definitions related to operations. 405 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 406 if (contains(op.loc, posLoc)) { 407 for (const auto &result : op.resultGroups) 408 appendSMDef(result.definition); 409 for (const auto &symUse : op.symbolUses) 410 if (contains(symUse, posLoc)) 411 references.push_back(getLocationFromLoc(sourceMgr, symUse, uri)); 412 return; 413 } 414 for (const auto &result : op.resultGroups) 415 if (isDefOrUse(result.definition, posLoc)) 416 return appendSMDef(result.definition); 417 for (const auto &symUse : op.symbolUses) { 418 if (!contains(symUse, posLoc)) 419 continue; 420 for (const auto &symUse : op.symbolUses) 421 references.push_back(getLocationFromLoc(sourceMgr, symUse, uri)); 422 return; 423 } 424 } 425 426 // Check all definitions related to blocks. 427 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 428 if (isDefOrUse(block.definition, posLoc)) 429 return appendSMDef(block.definition); 430 431 for (const AsmParserState::SMDefinition &arg : block.arguments) 432 if (isDefOrUse(arg, posLoc)) 433 return appendSMDef(arg); 434 } 435 } 436 437 //===----------------------------------------------------------------------===// 438 // MLIRDocument: Hover 439 //===----------------------------------------------------------------------===// 440 441 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri, 442 const lsp::Position &hoverPos) { 443 SMLoc posLoc = getPosFromLoc(sourceMgr, hoverPos); 444 SMRange hoverRange; 445 446 // Check for Hovers on operations and results. 447 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 448 // Check if the position points at this operation. 449 if (contains(op.loc, posLoc)) 450 return buildHoverForOperation(op.loc, op); 451 452 // Check if the position points at the symbol name. 453 for (auto &use : op.symbolUses) 454 if (contains(use, posLoc)) 455 return buildHoverForOperation(use, op); 456 457 // Check if the position points at a result group. 458 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) { 459 const auto &result = op.resultGroups[i]; 460 if (!isDefOrUse(result.definition, posLoc, &hoverRange)) 461 continue; 462 463 // Get the range of results covered by the over position. 464 unsigned resultStart = result.startIndex; 465 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults() 466 : op.resultGroups[i + 1].startIndex; 467 return buildHoverForOperationResult(hoverRange, op.op, resultStart, 468 resultEnd, posLoc); 469 } 470 } 471 472 // Check to see if the hover is over a block argument. 473 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 474 if (isDefOrUse(block.definition, posLoc, &hoverRange)) 475 return buildHoverForBlock(hoverRange, block); 476 477 for (const auto &arg : llvm::enumerate(block.arguments)) { 478 if (!isDefOrUse(arg.value(), posLoc, &hoverRange)) 479 continue; 480 481 return buildHoverForBlockArgument( 482 hoverRange, block.block->getArgument(arg.index()), block); 483 } 484 } 485 return llvm::None; 486 } 487 488 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation( 489 SMRange hoverRange, const AsmParserState::OperationDefinition &op) { 490 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 491 llvm::raw_string_ostream os(hover.contents.value); 492 493 // Add the operation name to the hover. 494 os << "\"" << op.op->getName() << "\""; 495 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op)) 496 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << ""; 497 os << "\n\n"; 498 499 os << "Generic Form:\n\n```mlir\n"; 500 501 // Temporary drop the regions of this operation so that they don't get 502 // printed in the output. This helps keeps the size of the output hover 503 // small. 504 SmallVector<std::unique_ptr<Region>> regions; 505 for (Region ®ion : op.op->getRegions()) { 506 regions.emplace_back(std::make_unique<Region>()); 507 regions.back()->takeBody(region); 508 } 509 510 op.op->print( 511 os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); 512 os << "\n```\n"; 513 514 // Move the regions back to the current operation. 515 for (Region ®ion : op.op->getRegions()) 516 region.takeBody(*regions.back()); 517 518 return hover; 519 } 520 521 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange, 522 Operation *op, 523 unsigned resultStart, 524 unsigned resultEnd, 525 SMLoc posLoc) { 526 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 527 llvm::raw_string_ostream os(hover.contents.value); 528 529 // Add the parent operation name to the hover. 530 os << "Operation: \"" << op->getName() << "\"\n\n"; 531 532 // Check to see if the location points to a specific result within the 533 // group. 534 if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) { 535 if ((resultStart + *resultNumber) < resultEnd) { 536 resultStart += *resultNumber; 537 resultEnd = resultStart + 1; 538 } 539 } 540 541 // Add the range of results and their types to the hover info. 542 if ((resultStart + 1) == resultEnd) { 543 os << "Result #" << resultStart << "\n\n" 544 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n"; 545 } else { 546 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n" 547 << "Types: "; 548 llvm::interleaveComma( 549 op->getResults().slice(resultStart, resultEnd), os, 550 [&](Value result) { os << "`" << result.getType() << "`"; }); 551 } 552 553 return hover; 554 } 555 556 lsp::Hover 557 MLIRDocument::buildHoverForBlock(SMRange hoverRange, 558 const AsmParserState::BlockDefinition &block) { 559 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 560 llvm::raw_string_ostream os(hover.contents.value); 561 562 // Print the given block to the hover output stream. 563 auto printBlockToHover = [&](Block *newBlock) { 564 if (const auto *def = asmState.getBlockDef(newBlock)) 565 printDefBlockName(os, *def); 566 else 567 printDefBlockName(os, newBlock); 568 }; 569 570 // Display the parent operation, block number, predecessors, and successors. 571 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 572 << "Block #" << getBlockNumber(block.block) << "\n\n"; 573 if (!block.block->hasNoPredecessors()) { 574 os << "Predecessors: "; 575 llvm::interleaveComma(block.block->getPredecessors(), os, 576 printBlockToHover); 577 os << "\n\n"; 578 } 579 if (!block.block->hasNoSuccessors()) { 580 os << "Successors: "; 581 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover); 582 os << "\n\n"; 583 } 584 585 return hover; 586 } 587 588 lsp::Hover MLIRDocument::buildHoverForBlockArgument( 589 SMRange hoverRange, BlockArgument arg, 590 const AsmParserState::BlockDefinition &block) { 591 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); 592 llvm::raw_string_ostream os(hover.contents.value); 593 594 // Display the parent operation, block, the argument number, and the type. 595 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 596 << "Block: "; 597 printDefBlockName(os, block); 598 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n" 599 << "Type: `" << arg.getType() << "`\n\n"; 600 601 return hover; 602 } 603 604 //===----------------------------------------------------------------------===// 605 // MLIRDocument: Document Symbols 606 //===----------------------------------------------------------------------===// 607 608 void MLIRDocument::findDocumentSymbols( 609 std::vector<lsp::DocumentSymbol> &symbols) { 610 for (Operation &op : parsedIR) 611 findDocumentSymbols(&op, symbols); 612 } 613 614 void MLIRDocument::findDocumentSymbols( 615 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) { 616 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols; 617 618 // Check for the source information of this operation. 619 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) { 620 // If this operation defines a symbol, record it. 621 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) { 622 symbols.emplace_back(symbol.getName(), 623 isa<FunctionOpInterface>(op) 624 ? lsp::SymbolKind::Function 625 : lsp::SymbolKind::Class, 626 getRangeFromLoc(sourceMgr, def->scopeLoc), 627 getRangeFromLoc(sourceMgr, def->loc)); 628 childSymbols = &symbols.back().children; 629 630 } else if (op->hasTrait<OpTrait::SymbolTable>()) { 631 // Otherwise, if this is a symbol table push an anonymous document symbol. 632 symbols.emplace_back("<" + op->getName().getStringRef() + ">", 633 lsp::SymbolKind::Namespace, 634 getRangeFromLoc(sourceMgr, def->scopeLoc), 635 getRangeFromLoc(sourceMgr, def->loc)); 636 childSymbols = &symbols.back().children; 637 } 638 } 639 640 // Recurse into the regions of this operation. 641 if (!op->getNumRegions()) 642 return; 643 for (Region ®ion : op->getRegions()) 644 for (Operation &childOp : region.getOps()) 645 findDocumentSymbols(&childOp, *childSymbols); 646 } 647 648 //===----------------------------------------------------------------------===// 649 // MLIRTextFileChunk 650 //===----------------------------------------------------------------------===// 651 652 namespace { 653 /// This class represents a single chunk of an MLIR text file. 654 struct MLIRTextFileChunk { 655 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset, 656 const lsp::URIForFile &uri, StringRef contents, 657 std::vector<lsp::Diagnostic> &diagnostics) 658 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {} 659 660 /// Adjust the line number of the given range to anchor at the beginning of 661 /// the file, instead of the beginning of this chunk. 662 void adjustLocForChunkOffset(lsp::Range &range) { 663 adjustLocForChunkOffset(range.start); 664 adjustLocForChunkOffset(range.end); 665 } 666 /// Adjust the line number of the given position to anchor at the beginning of 667 /// the file, instead of the beginning of this chunk. 668 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } 669 670 /// The line offset of this chunk from the beginning of the file. 671 uint64_t lineOffset; 672 /// The document referred to by this chunk. 673 MLIRDocument document; 674 }; 675 } // namespace 676 677 //===----------------------------------------------------------------------===// 678 // MLIRTextFile 679 //===----------------------------------------------------------------------===// 680 681 namespace { 682 /// This class represents a text file containing one or more MLIR documents. 683 class MLIRTextFile { 684 public: 685 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 686 int64_t version, DialectRegistry ®istry, 687 std::vector<lsp::Diagnostic> &diagnostics); 688 689 /// Return the current version of this text file. 690 int64_t getVersion() const { return version; } 691 692 //===--------------------------------------------------------------------===// 693 // LSP Queries 694 //===--------------------------------------------------------------------===// 695 696 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos, 697 std::vector<lsp::Location> &locations); 698 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos, 699 std::vector<lsp::Location> &references); 700 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 701 lsp::Position hoverPos); 702 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 703 704 private: 705 /// Find the MLIR document that contains the given position, and update the 706 /// position to be anchored at the start of the found chunk instead of the 707 /// beginning of the file. 708 MLIRTextFileChunk &getChunkFor(lsp::Position &pos); 709 710 /// The context used to hold the state contained by the parsed document. 711 MLIRContext context; 712 713 /// The full string contents of the file. 714 std::string contents; 715 716 /// The version of this file. 717 int64_t version; 718 719 /// The number of lines in the file. 720 int64_t totalNumLines = 0; 721 722 /// The chunks of this file. The order of these chunks is the order in which 723 /// they appear in the text file. 724 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks; 725 }; 726 } // namespace 727 728 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 729 int64_t version, DialectRegistry ®istry, 730 std::vector<lsp::Diagnostic> &diagnostics) 731 : context(registry, MLIRContext::Threading::DISABLED), 732 contents(fileContents.str()), version(version) { 733 context.allowUnregisteredDialects(); 734 735 // Split the file into separate MLIR documents. 736 // TODO: Find a way to share the split file marker with other tools. We don't 737 // want to use `splitAndProcessBuffer` here, but we do want to make sure this 738 // marker doesn't go out of sync. 739 SmallVector<StringRef, 8> subContents; 740 StringRef(contents).split(subContents, "// -----"); 741 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>( 742 context, /*lineOffset=*/0, uri, subContents.front(), diagnostics)); 743 744 uint64_t lineOffset = subContents.front().count('\n'); 745 for (StringRef docContents : llvm::drop_begin(subContents)) { 746 unsigned currentNumDiags = diagnostics.size(); 747 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri, 748 docContents, diagnostics); 749 lineOffset += docContents.count('\n'); 750 751 // Adjust locations used in diagnostics to account for the offset from the 752 // beginning of the file. 753 for (lsp::Diagnostic &diag : 754 llvm::drop_begin(diagnostics, currentNumDiags)) { 755 chunk->adjustLocForChunkOffset(diag.range); 756 757 if (!diag.relatedInformation) 758 continue; 759 for (auto &it : *diag.relatedInformation) 760 if (it.location.uri == uri) 761 chunk->adjustLocForChunkOffset(it.location.range); 762 } 763 chunks.emplace_back(std::move(chunk)); 764 } 765 totalNumLines = lineOffset; 766 } 767 768 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri, 769 lsp::Position defPos, 770 std::vector<lsp::Location> &locations) { 771 MLIRTextFileChunk &chunk = getChunkFor(defPos); 772 chunk.document.getLocationsOf(uri, defPos, locations); 773 774 // Adjust any locations within this file for the offset of this chunk. 775 if (chunk.lineOffset == 0) 776 return; 777 for (lsp::Location &loc : locations) 778 if (loc.uri == uri) 779 chunk.adjustLocForChunkOffset(loc.range); 780 } 781 782 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri, 783 lsp::Position pos, 784 std::vector<lsp::Location> &references) { 785 MLIRTextFileChunk &chunk = getChunkFor(pos); 786 chunk.document.findReferencesOf(uri, pos, references); 787 788 // Adjust any locations within this file for the offset of this chunk. 789 if (chunk.lineOffset == 0) 790 return; 791 for (lsp::Location &loc : references) 792 if (loc.uri == uri) 793 chunk.adjustLocForChunkOffset(loc.range); 794 } 795 796 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri, 797 lsp::Position hoverPos) { 798 MLIRTextFileChunk &chunk = getChunkFor(hoverPos); 799 Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos); 800 801 // Adjust any locations within this file for the offset of this chunk. 802 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) 803 chunk.adjustLocForChunkOffset(*hoverInfo->range); 804 return hoverInfo; 805 } 806 807 void MLIRTextFile::findDocumentSymbols( 808 std::vector<lsp::DocumentSymbol> &symbols) { 809 if (chunks.size() == 1) 810 return chunks.front()->document.findDocumentSymbols(symbols); 811 812 // If there are multiple chunks in this file, we create top-level symbols for 813 // each chunk. 814 for (unsigned i = 0, e = chunks.size(); i < e; ++i) { 815 MLIRTextFileChunk &chunk = *chunks[i]; 816 lsp::Position startPos(chunk.lineOffset); 817 lsp::Position endPos((i == e - 1) ? totalNumLines - 1 818 : chunks[i + 1]->lineOffset); 819 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">", 820 lsp::SymbolKind::Namespace, 821 /*range=*/lsp::Range(startPos, endPos), 822 /*selectionRange=*/lsp::Range(startPos)); 823 chunk.document.findDocumentSymbols(symbol.children); 824 825 // Fixup the locations of document symbols within this chunk. 826 if (i != 0) { 827 SmallVector<lsp::DocumentSymbol *> symbolsToFix; 828 for (lsp::DocumentSymbol &childSymbol : symbol.children) 829 symbolsToFix.push_back(&childSymbol); 830 831 while (!symbolsToFix.empty()) { 832 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); 833 chunk.adjustLocForChunkOffset(symbol->range); 834 chunk.adjustLocForChunkOffset(symbol->selectionRange); 835 836 for (lsp::DocumentSymbol &childSymbol : symbol->children) 837 symbolsToFix.push_back(&childSymbol); 838 } 839 } 840 841 // Push the symbol for this chunk. 842 symbols.emplace_back(std::move(symbol)); 843 } 844 } 845 846 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { 847 if (chunks.size() == 1) 848 return *chunks.front(); 849 850 // Search for the first chunk with a greater line offset, the previous chunk 851 // is the one that contains `pos`. 852 auto it = llvm::upper_bound( 853 chunks, pos, [](const lsp::Position &pos, const auto &chunk) { 854 return static_cast<uint64_t>(pos.line) < chunk->lineOffset; 855 }); 856 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it); 857 pos.line -= chunk.lineOffset; 858 return chunk; 859 } 860 861 //===----------------------------------------------------------------------===// 862 // MLIRServer::Impl 863 //===----------------------------------------------------------------------===// 864 865 struct lsp::MLIRServer::Impl { 866 Impl(DialectRegistry ®istry) : registry(registry) {} 867 868 /// The registry containing dialects that can be recognized in parsed .mlir 869 /// files. 870 DialectRegistry ®istry; 871 872 /// The files held by the server, mapped by their URI file name. 873 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files; 874 }; 875 876 //===----------------------------------------------------------------------===// 877 // MLIRServer 878 //===----------------------------------------------------------------------===// 879 880 lsp::MLIRServer::MLIRServer(DialectRegistry ®istry) 881 : impl(std::make_unique<Impl>(registry)) {} 882 lsp::MLIRServer::~MLIRServer() = default; 883 884 void lsp::MLIRServer::addOrUpdateDocument( 885 const URIForFile &uri, StringRef contents, int64_t version, 886 std::vector<Diagnostic> &diagnostics) { 887 impl->files[uri.file()] = std::make_unique<MLIRTextFile>( 888 uri, contents, version, impl->registry, diagnostics); 889 } 890 891 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) { 892 auto it = impl->files.find(uri.file()); 893 if (it == impl->files.end()) 894 return llvm::None; 895 896 int64_t version = it->second->getVersion(); 897 impl->files.erase(it); 898 return version; 899 } 900 901 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, 902 const Position &defPos, 903 std::vector<Location> &locations) { 904 auto fileIt = impl->files.find(uri.file()); 905 if (fileIt != impl->files.end()) 906 fileIt->second->getLocationsOf(uri, defPos, locations); 907 } 908 909 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, 910 const Position &pos, 911 std::vector<Location> &references) { 912 auto fileIt = impl->files.find(uri.file()); 913 if (fileIt != impl->files.end()) 914 fileIt->second->findReferencesOf(uri, pos, references); 915 } 916 917 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri, 918 const Position &hoverPos) { 919 auto fileIt = impl->files.find(uri.file()); 920 if (fileIt != impl->files.end()) 921 return fileIt->second->findHover(uri, hoverPos); 922 return llvm::None; 923 } 924 925 void lsp::MLIRServer::findDocumentSymbols( 926 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) { 927 auto fileIt = impl->files.find(uri.file()); 928 if (fileIt != impl->files.end()) 929 fileIt->second->findDocumentSymbols(symbols); 930 } 931