1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MLIRServer.h" 10 #include "../lsp-server-support/Logging.h" 11 #include "../lsp-server-support/Protocol.h" 12 #include "../lsp-server-support/SourceMgrUtils.h" 13 #include "mlir/IR/FunctionInterfaces.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/Parser/AsmParserState.h" 16 #include "mlir/Parser/Parser.h" 17 #include "llvm/Support/SourceMgr.h" 18 19 using namespace mlir; 20 21 /// Returns a language server location from the given MLIR file location. 22 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) { 23 llvm::Expected<lsp::URIForFile> sourceURI = 24 lsp::URIForFile::fromFile(loc.getFilename()); 25 if (!sourceURI) { 26 lsp::Logger::error("Failed to create URI for file `{0}`: {1}", 27 loc.getFilename(), 28 llvm::toString(sourceURI.takeError())); 29 return llvm::None; 30 } 31 32 lsp::Position position; 33 position.line = loc.getLine() - 1; 34 position.character = loc.getColumn() ? loc.getColumn() - 1 : 0; 35 return lsp::Location{*sourceURI, lsp::Range(position)}; 36 } 37 38 /// Returns a language server location from the given MLIR location, or None if 39 /// one couldn't be created. `uri` is an optional additional filter that, when 40 /// present, is used to filter sub locations that do not share the same uri. 41 static Optional<lsp::Location> 42 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc, 43 const lsp::URIForFile *uri = nullptr) { 44 Optional<lsp::Location> location; 45 loc->walk([&](Location nestedLoc) { 46 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 47 if (!fileLoc) 48 return WalkResult::advance(); 49 50 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 51 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) { 52 location = *sourceLoc; 53 SMLoc loc = sourceMgr.FindLocForLineAndColumn( 54 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn()); 55 56 // Use range of potential identifier starting at location, else length 1 57 // range. 58 location->range.end.character += 1; 59 if (Optional<SMRange> range = lsp::convertTokenLocToRange(loc)) { 60 auto lineCol = sourceMgr.getLineAndColumn(range->End); 61 location->range.end.character = 62 std::max(fileLoc.getColumn() + 1, lineCol.second - 1); 63 } 64 return WalkResult::interrupt(); 65 } 66 return WalkResult::advance(); 67 }); 68 return location; 69 } 70 71 /// Collect all of the locations from the given MLIR location that are not 72 /// contained within the given URI. 73 static void collectLocationsFromLoc(Location loc, 74 std::vector<lsp::Location> &locations, 75 const lsp::URIForFile &uri) { 76 SetVector<Location> visitedLocs; 77 loc->walk([&](Location nestedLoc) { 78 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 79 if (!fileLoc || !visitedLocs.insert(nestedLoc)) 80 return WalkResult::advance(); 81 82 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 83 if (sourceLoc && sourceLoc->uri != uri) 84 locations.push_back(*sourceLoc); 85 return WalkResult::advance(); 86 }); 87 } 88 89 /// Returns true if the given range contains the given source location. Note 90 /// that this has slightly different behavior than SMRange because it is 91 /// inclusive of the end location. 92 static bool contains(SMRange range, SMLoc loc) { 93 return range.Start.getPointer() <= loc.getPointer() && 94 loc.getPointer() <= range.End.getPointer(); 95 } 96 97 /// Returns true if the given location is contained by the definition or one of 98 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to 99 /// the range within `def` that the provided `loc` overlapped with. 100 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, 101 SMRange *overlappedRange = nullptr) { 102 // Check the main definition. 103 if (contains(def.loc, loc)) { 104 if (overlappedRange) 105 *overlappedRange = def.loc; 106 return true; 107 } 108 109 // Check the uses. 110 const auto *useIt = llvm::find_if( 111 def.uses, [&](const SMRange &range) { return contains(range, loc); }); 112 if (useIt != def.uses.end()) { 113 if (overlappedRange) 114 *overlappedRange = *useIt; 115 return true; 116 } 117 return false; 118 } 119 120 /// Given a location pointing to a result, return the result number it refers 121 /// to or None if it refers to all of the results. 122 static Optional<unsigned> getResultNumberFromLoc(SMLoc loc) { 123 // Skip all of the identifier characters. 124 auto isIdentifierChar = [](char c) { 125 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' || 126 c == '-'; 127 }; 128 const char *curPtr = loc.getPointer(); 129 while (isIdentifierChar(*curPtr)) 130 ++curPtr; 131 132 // Check to see if this location indexes into the result group, via `#`. If it 133 // doesn't, we can't extract a sub result number. 134 if (*curPtr != '#') 135 return llvm::None; 136 137 // Compute the sub result number from the remaining portion of the string. 138 const char *numberStart = ++curPtr; 139 while (llvm::isDigit(*curPtr)) 140 ++curPtr; 141 StringRef numberStr(numberStart, curPtr - numberStart); 142 unsigned resultNumber = 0; 143 return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>() 144 : resultNumber; 145 } 146 147 /// Given a source location range, return the text covered by the given range. 148 /// If the range is invalid, returns None. 149 static Optional<StringRef> getTextFromRange(SMRange range) { 150 if (!range.isValid()) 151 return None; 152 const char *startPtr = range.Start.getPointer(); 153 return StringRef(startPtr, range.End.getPointer() - startPtr); 154 } 155 156 /// Given a block, return its position in its parent region. 157 static unsigned getBlockNumber(Block *block) { 158 return std::distance(block->getParent()->begin(), block->getIterator()); 159 } 160 161 /// Given a block and source location, print the source name of the block to the 162 /// given output stream. 163 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) { 164 // Try to extract a name from the source location. 165 Optional<StringRef> text = getTextFromRange(loc); 166 if (text && text->startswith("^")) { 167 os << *text; 168 return; 169 } 170 171 // Otherwise, we don't have a name so print the block number. 172 os << "<Block #" << getBlockNumber(block) << ">"; 173 } 174 static void printDefBlockName(raw_ostream &os, 175 const AsmParserState::BlockDefinition &def) { 176 printDefBlockName(os, def.block, def.definition.loc); 177 } 178 179 /// Convert the given MLIR diagnostic to the LSP form. 180 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, 181 Diagnostic &diag, 182 const lsp::URIForFile &uri) { 183 lsp::Diagnostic lspDiag; 184 lspDiag.source = "mlir"; 185 186 // Note: Right now all of the diagnostics are treated as parser issues, but 187 // some are parser and some are verifier. 188 lspDiag.category = "Parse Error"; 189 190 // Try to grab a file location for this diagnostic. 191 // TODO: For simplicity, we just grab the first one. It may be likely that we 192 // will need a more interesting heuristic here.' 193 Optional<lsp::Location> lspLocation = 194 getLocationFromLoc(sourceMgr, diag.getLocation(), &uri); 195 if (lspLocation) 196 lspDiag.range = lspLocation->range; 197 198 // Convert the severity for the diagnostic. 199 switch (diag.getSeverity()) { 200 case DiagnosticSeverity::Note: 201 llvm_unreachable("expected notes to be handled separately"); 202 case DiagnosticSeverity::Warning: 203 lspDiag.severity = lsp::DiagnosticSeverity::Warning; 204 break; 205 case DiagnosticSeverity::Error: 206 lspDiag.severity = lsp::DiagnosticSeverity::Error; 207 break; 208 case DiagnosticSeverity::Remark: 209 lspDiag.severity = lsp::DiagnosticSeverity::Information; 210 break; 211 } 212 lspDiag.message = diag.str(); 213 214 // Attach any notes to the main diagnostic as related information. 215 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags; 216 for (Diagnostic ¬e : diag.getNotes()) { 217 lsp::Location noteLoc; 218 if (Optional<lsp::Location> loc = 219 getLocationFromLoc(sourceMgr, note.getLocation())) 220 noteLoc = *loc; 221 else 222 noteLoc.uri = uri; 223 relatedDiags.emplace_back(noteLoc, note.str()); 224 } 225 if (!relatedDiags.empty()) 226 lspDiag.relatedInformation = std::move(relatedDiags); 227 228 return lspDiag; 229 } 230 231 //===----------------------------------------------------------------------===// 232 // MLIRDocument 233 //===----------------------------------------------------------------------===// 234 235 namespace { 236 /// This class represents all of the information pertaining to a specific MLIR 237 /// document. 238 struct MLIRDocument { 239 MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, 240 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics); 241 MLIRDocument(const MLIRDocument &) = delete; 242 MLIRDocument &operator=(const MLIRDocument &) = delete; 243 244 //===--------------------------------------------------------------------===// 245 // Definitions and References 246 //===--------------------------------------------------------------------===// 247 248 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, 249 std::vector<lsp::Location> &locations); 250 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, 251 std::vector<lsp::Location> &references); 252 253 //===--------------------------------------------------------------------===// 254 // Hover 255 //===--------------------------------------------------------------------===// 256 257 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 258 const lsp::Position &hoverPos); 259 Optional<lsp::Hover> 260 buildHoverForOperation(SMRange hoverRange, 261 const AsmParserState::OperationDefinition &op); 262 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op, 263 unsigned resultStart, 264 unsigned resultEnd, SMLoc posLoc); 265 lsp::Hover buildHoverForBlock(SMRange hoverRange, 266 const AsmParserState::BlockDefinition &block); 267 lsp::Hover 268 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg, 269 const AsmParserState::BlockDefinition &block); 270 271 //===--------------------------------------------------------------------===// 272 // Document Symbols 273 //===--------------------------------------------------------------------===// 274 275 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 276 void findDocumentSymbols(Operation *op, 277 std::vector<lsp::DocumentSymbol> &symbols); 278 279 //===--------------------------------------------------------------------===// 280 // Fields 281 //===--------------------------------------------------------------------===// 282 283 /// The high level parser state used to find definitions and references within 284 /// the source file. 285 AsmParserState asmState; 286 287 /// The container for the IR parsed from the input file. 288 Block parsedIR; 289 290 /// The source manager containing the contents of the input file. 291 llvm::SourceMgr sourceMgr; 292 }; 293 } // namespace 294 295 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, 296 StringRef contents, 297 std::vector<lsp::Diagnostic> &diagnostics) { 298 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) { 299 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri)); 300 }); 301 302 // Try to parsed the given IR string. 303 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); 304 if (!memBuffer) { 305 lsp::Logger::error("Failed to create memory buffer for file", uri.file()); 306 return; 307 } 308 309 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 310 if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr, 311 &asmState))) { 312 // If parsing failed, clear out any of the current state. 313 parsedIR.clear(); 314 asmState = AsmParserState(); 315 return; 316 } 317 } 318 319 //===----------------------------------------------------------------------===// 320 // MLIRDocument: Definitions and References 321 //===----------------------------------------------------------------------===// 322 323 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri, 324 const lsp::Position &defPos, 325 std::vector<lsp::Location> &locations) { 326 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); 327 328 // Functor used to check if an SM definition contains the position. 329 auto containsPosition = [&](const AsmParserState::SMDefinition &def) { 330 if (!isDefOrUse(def, posLoc)) 331 return false; 332 locations.emplace_back(uri, sourceMgr, def.loc); 333 return true; 334 }; 335 336 // Check all definitions related to operations. 337 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 338 if (contains(op.loc, posLoc)) 339 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 340 for (const auto &result : op.resultGroups) 341 if (containsPosition(result.definition)) 342 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 343 for (const auto &symUse : op.symbolUses) { 344 if (contains(symUse, posLoc)) { 345 locations.emplace_back(uri, sourceMgr, op.loc); 346 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 347 } 348 } 349 } 350 351 // Check all definitions related to blocks. 352 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 353 if (containsPosition(block.definition)) 354 return; 355 for (const AsmParserState::SMDefinition &arg : block.arguments) 356 if (containsPosition(arg)) 357 return; 358 } 359 } 360 361 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri, 362 const lsp::Position &pos, 363 std::vector<lsp::Location> &references) { 364 // Functor used to append all of the definitions/uses of the given SM 365 // definition to the reference list. 366 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) { 367 references.emplace_back(uri, sourceMgr, def.loc); 368 for (const SMRange &use : def.uses) 369 references.emplace_back(uri, sourceMgr, use); 370 }; 371 372 SMLoc posLoc = pos.getAsSMLoc(sourceMgr); 373 374 // Check all definitions related to operations. 375 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 376 if (contains(op.loc, posLoc)) { 377 for (const auto &result : op.resultGroups) 378 appendSMDef(result.definition); 379 for (const auto &symUse : op.symbolUses) 380 if (contains(symUse, posLoc)) 381 references.emplace_back(uri, sourceMgr, symUse); 382 return; 383 } 384 for (const auto &result : op.resultGroups) 385 if (isDefOrUse(result.definition, posLoc)) 386 return appendSMDef(result.definition); 387 for (const auto &symUse : op.symbolUses) { 388 if (!contains(symUse, posLoc)) 389 continue; 390 for (const auto &symUse : op.symbolUses) 391 references.emplace_back(uri, sourceMgr, symUse); 392 return; 393 } 394 } 395 396 // Check all definitions related to blocks. 397 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 398 if (isDefOrUse(block.definition, posLoc)) 399 return appendSMDef(block.definition); 400 401 for (const AsmParserState::SMDefinition &arg : block.arguments) 402 if (isDefOrUse(arg, posLoc)) 403 return appendSMDef(arg); 404 } 405 } 406 407 //===----------------------------------------------------------------------===// 408 // MLIRDocument: Hover 409 //===----------------------------------------------------------------------===// 410 411 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri, 412 const lsp::Position &hoverPos) { 413 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr); 414 SMRange hoverRange; 415 416 // Check for Hovers on operations and results. 417 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 418 // Check if the position points at this operation. 419 if (contains(op.loc, posLoc)) 420 return buildHoverForOperation(op.loc, op); 421 422 // Check if the position points at the symbol name. 423 for (auto &use : op.symbolUses) 424 if (contains(use, posLoc)) 425 return buildHoverForOperation(use, op); 426 427 // Check if the position points at a result group. 428 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) { 429 const auto &result = op.resultGroups[i]; 430 if (!isDefOrUse(result.definition, posLoc, &hoverRange)) 431 continue; 432 433 // Get the range of results covered by the over position. 434 unsigned resultStart = result.startIndex; 435 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults() 436 : op.resultGroups[i + 1].startIndex; 437 return buildHoverForOperationResult(hoverRange, op.op, resultStart, 438 resultEnd, posLoc); 439 } 440 } 441 442 // Check to see if the hover is over a block argument. 443 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 444 if (isDefOrUse(block.definition, posLoc, &hoverRange)) 445 return buildHoverForBlock(hoverRange, block); 446 447 for (const auto &arg : llvm::enumerate(block.arguments)) { 448 if (!isDefOrUse(arg.value(), posLoc, &hoverRange)) 449 continue; 450 451 return buildHoverForBlockArgument( 452 hoverRange, block.block->getArgument(arg.index()), block); 453 } 454 } 455 return llvm::None; 456 } 457 458 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation( 459 SMRange hoverRange, const AsmParserState::OperationDefinition &op) { 460 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 461 llvm::raw_string_ostream os(hover.contents.value); 462 463 // Add the operation name to the hover. 464 os << "\"" << op.op->getName() << "\""; 465 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op)) 466 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << ""; 467 os << "\n\n"; 468 469 os << "Generic Form:\n\n```mlir\n"; 470 471 // Temporary drop the regions of this operation so that they don't get 472 // printed in the output. This helps keeps the size of the output hover 473 // small. 474 SmallVector<std::unique_ptr<Region>> regions; 475 for (Region ®ion : op.op->getRegions()) { 476 regions.emplace_back(std::make_unique<Region>()); 477 regions.back()->takeBody(region); 478 } 479 480 op.op->print( 481 os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); 482 os << "\n```\n"; 483 484 // Move the regions back to the current operation. 485 for (Region ®ion : op.op->getRegions()) 486 region.takeBody(*regions.back()); 487 488 return hover; 489 } 490 491 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange, 492 Operation *op, 493 unsigned resultStart, 494 unsigned resultEnd, 495 SMLoc posLoc) { 496 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 497 llvm::raw_string_ostream os(hover.contents.value); 498 499 // Add the parent operation name to the hover. 500 os << "Operation: \"" << op->getName() << "\"\n\n"; 501 502 // Check to see if the location points to a specific result within the 503 // group. 504 if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) { 505 if ((resultStart + *resultNumber) < resultEnd) { 506 resultStart += *resultNumber; 507 resultEnd = resultStart + 1; 508 } 509 } 510 511 // Add the range of results and their types to the hover info. 512 if ((resultStart + 1) == resultEnd) { 513 os << "Result #" << resultStart << "\n\n" 514 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n"; 515 } else { 516 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n" 517 << "Types: "; 518 llvm::interleaveComma( 519 op->getResults().slice(resultStart, resultEnd), os, 520 [&](Value result) { os << "`" << result.getType() << "`"; }); 521 } 522 523 return hover; 524 } 525 526 lsp::Hover 527 MLIRDocument::buildHoverForBlock(SMRange hoverRange, 528 const AsmParserState::BlockDefinition &block) { 529 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 530 llvm::raw_string_ostream os(hover.contents.value); 531 532 // Print the given block to the hover output stream. 533 auto printBlockToHover = [&](Block *newBlock) { 534 if (const auto *def = asmState.getBlockDef(newBlock)) 535 printDefBlockName(os, *def); 536 else 537 printDefBlockName(os, newBlock); 538 }; 539 540 // Display the parent operation, block number, predecessors, and successors. 541 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 542 << "Block #" << getBlockNumber(block.block) << "\n\n"; 543 if (!block.block->hasNoPredecessors()) { 544 os << "Predecessors: "; 545 llvm::interleaveComma(block.block->getPredecessors(), os, 546 printBlockToHover); 547 os << "\n\n"; 548 } 549 if (!block.block->hasNoSuccessors()) { 550 os << "Successors: "; 551 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover); 552 os << "\n\n"; 553 } 554 555 return hover; 556 } 557 558 lsp::Hover MLIRDocument::buildHoverForBlockArgument( 559 SMRange hoverRange, BlockArgument arg, 560 const AsmParserState::BlockDefinition &block) { 561 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 562 llvm::raw_string_ostream os(hover.contents.value); 563 564 // Display the parent operation, block, the argument number, and the type. 565 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 566 << "Block: "; 567 printDefBlockName(os, block); 568 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n" 569 << "Type: `" << arg.getType() << "`\n\n"; 570 571 return hover; 572 } 573 574 //===----------------------------------------------------------------------===// 575 // MLIRDocument: Document Symbols 576 //===----------------------------------------------------------------------===// 577 578 void MLIRDocument::findDocumentSymbols( 579 std::vector<lsp::DocumentSymbol> &symbols) { 580 for (Operation &op : parsedIR) 581 findDocumentSymbols(&op, symbols); 582 } 583 584 void MLIRDocument::findDocumentSymbols( 585 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) { 586 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols; 587 588 // Check for the source information of this operation. 589 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) { 590 // If this operation defines a symbol, record it. 591 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) { 592 symbols.emplace_back(symbol.getName(), 593 isa<FunctionOpInterface>(op) 594 ? lsp::SymbolKind::Function 595 : lsp::SymbolKind::Class, 596 lsp::Range(sourceMgr, def->scopeLoc), 597 lsp::Range(sourceMgr, def->loc)); 598 childSymbols = &symbols.back().children; 599 600 } else if (op->hasTrait<OpTrait::SymbolTable>()) { 601 // Otherwise, if this is a symbol table push an anonymous document symbol. 602 symbols.emplace_back("<" + op->getName().getStringRef() + ">", 603 lsp::SymbolKind::Namespace, 604 lsp::Range(sourceMgr, def->scopeLoc), 605 lsp::Range(sourceMgr, def->loc)); 606 childSymbols = &symbols.back().children; 607 } 608 } 609 610 // Recurse into the regions of this operation. 611 if (!op->getNumRegions()) 612 return; 613 for (Region ®ion : op->getRegions()) 614 for (Operation &childOp : region.getOps()) 615 findDocumentSymbols(&childOp, *childSymbols); 616 } 617 618 //===----------------------------------------------------------------------===// 619 // MLIRTextFileChunk 620 //===----------------------------------------------------------------------===// 621 622 namespace { 623 /// This class represents a single chunk of an MLIR text file. 624 struct MLIRTextFileChunk { 625 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset, 626 const lsp::URIForFile &uri, StringRef contents, 627 std::vector<lsp::Diagnostic> &diagnostics) 628 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {} 629 630 /// Adjust the line number of the given range to anchor at the beginning of 631 /// the file, instead of the beginning of this chunk. 632 void adjustLocForChunkOffset(lsp::Range &range) { 633 adjustLocForChunkOffset(range.start); 634 adjustLocForChunkOffset(range.end); 635 } 636 /// Adjust the line number of the given position to anchor at the beginning of 637 /// the file, instead of the beginning of this chunk. 638 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } 639 640 /// The line offset of this chunk from the beginning of the file. 641 uint64_t lineOffset; 642 /// The document referred to by this chunk. 643 MLIRDocument document; 644 }; 645 } // namespace 646 647 //===----------------------------------------------------------------------===// 648 // MLIRTextFile 649 //===----------------------------------------------------------------------===// 650 651 namespace { 652 /// This class represents a text file containing one or more MLIR documents. 653 class MLIRTextFile { 654 public: 655 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 656 int64_t version, DialectRegistry ®istry, 657 std::vector<lsp::Diagnostic> &diagnostics); 658 659 /// Return the current version of this text file. 660 int64_t getVersion() const { return version; } 661 662 //===--------------------------------------------------------------------===// 663 // LSP Queries 664 //===--------------------------------------------------------------------===// 665 666 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos, 667 std::vector<lsp::Location> &locations); 668 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos, 669 std::vector<lsp::Location> &references); 670 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 671 lsp::Position hoverPos); 672 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 673 674 private: 675 /// Find the MLIR document that contains the given position, and update the 676 /// position to be anchored at the start of the found chunk instead of the 677 /// beginning of the file. 678 MLIRTextFileChunk &getChunkFor(lsp::Position &pos); 679 680 /// The context used to hold the state contained by the parsed document. 681 MLIRContext context; 682 683 /// The full string contents of the file. 684 std::string contents; 685 686 /// The version of this file. 687 int64_t version; 688 689 /// The number of lines in the file. 690 int64_t totalNumLines = 0; 691 692 /// The chunks of this file. The order of these chunks is the order in which 693 /// they appear in the text file. 694 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks; 695 }; 696 } // namespace 697 698 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 699 int64_t version, DialectRegistry ®istry, 700 std::vector<lsp::Diagnostic> &diagnostics) 701 : context(registry, MLIRContext::Threading::DISABLED), 702 contents(fileContents.str()), version(version) { 703 context.allowUnregisteredDialects(); 704 705 // Split the file into separate MLIR documents. 706 // TODO: Find a way to share the split file marker with other tools. We don't 707 // want to use `splitAndProcessBuffer` here, but we do want to make sure this 708 // marker doesn't go out of sync. 709 SmallVector<StringRef, 8> subContents; 710 StringRef(contents).split(subContents, "// -----"); 711 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>( 712 context, /*lineOffset=*/0, uri, subContents.front(), diagnostics)); 713 714 uint64_t lineOffset = subContents.front().count('\n'); 715 for (StringRef docContents : llvm::drop_begin(subContents)) { 716 unsigned currentNumDiags = diagnostics.size(); 717 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri, 718 docContents, diagnostics); 719 lineOffset += docContents.count('\n'); 720 721 // Adjust locations used in diagnostics to account for the offset from the 722 // beginning of the file. 723 for (lsp::Diagnostic &diag : 724 llvm::drop_begin(diagnostics, currentNumDiags)) { 725 chunk->adjustLocForChunkOffset(diag.range); 726 727 if (!diag.relatedInformation) 728 continue; 729 for (auto &it : *diag.relatedInformation) 730 if (it.location.uri == uri) 731 chunk->adjustLocForChunkOffset(it.location.range); 732 } 733 chunks.emplace_back(std::move(chunk)); 734 } 735 totalNumLines = lineOffset; 736 } 737 738 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri, 739 lsp::Position defPos, 740 std::vector<lsp::Location> &locations) { 741 MLIRTextFileChunk &chunk = getChunkFor(defPos); 742 chunk.document.getLocationsOf(uri, defPos, locations); 743 744 // Adjust any locations within this file for the offset of this chunk. 745 if (chunk.lineOffset == 0) 746 return; 747 for (lsp::Location &loc : locations) 748 if (loc.uri == uri) 749 chunk.adjustLocForChunkOffset(loc.range); 750 } 751 752 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri, 753 lsp::Position pos, 754 std::vector<lsp::Location> &references) { 755 MLIRTextFileChunk &chunk = getChunkFor(pos); 756 chunk.document.findReferencesOf(uri, pos, references); 757 758 // Adjust any locations within this file for the offset of this chunk. 759 if (chunk.lineOffset == 0) 760 return; 761 for (lsp::Location &loc : references) 762 if (loc.uri == uri) 763 chunk.adjustLocForChunkOffset(loc.range); 764 } 765 766 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri, 767 lsp::Position hoverPos) { 768 MLIRTextFileChunk &chunk = getChunkFor(hoverPos); 769 Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos); 770 771 // Adjust any locations within this file for the offset of this chunk. 772 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) 773 chunk.adjustLocForChunkOffset(*hoverInfo->range); 774 return hoverInfo; 775 } 776 777 void MLIRTextFile::findDocumentSymbols( 778 std::vector<lsp::DocumentSymbol> &symbols) { 779 if (chunks.size() == 1) 780 return chunks.front()->document.findDocumentSymbols(symbols); 781 782 // If there are multiple chunks in this file, we create top-level symbols for 783 // each chunk. 784 for (unsigned i = 0, e = chunks.size(); i < e; ++i) { 785 MLIRTextFileChunk &chunk = *chunks[i]; 786 lsp::Position startPos(chunk.lineOffset); 787 lsp::Position endPos((i == e - 1) ? totalNumLines - 1 788 : chunks[i + 1]->lineOffset); 789 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">", 790 lsp::SymbolKind::Namespace, 791 /*range=*/lsp::Range(startPos, endPos), 792 /*selectionRange=*/lsp::Range(startPos)); 793 chunk.document.findDocumentSymbols(symbol.children); 794 795 // Fixup the locations of document symbols within this chunk. 796 if (i != 0) { 797 SmallVector<lsp::DocumentSymbol *> symbolsToFix; 798 for (lsp::DocumentSymbol &childSymbol : symbol.children) 799 symbolsToFix.push_back(&childSymbol); 800 801 while (!symbolsToFix.empty()) { 802 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); 803 chunk.adjustLocForChunkOffset(symbol->range); 804 chunk.adjustLocForChunkOffset(symbol->selectionRange); 805 806 for (lsp::DocumentSymbol &childSymbol : symbol->children) 807 symbolsToFix.push_back(&childSymbol); 808 } 809 } 810 811 // Push the symbol for this chunk. 812 symbols.emplace_back(std::move(symbol)); 813 } 814 } 815 816 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { 817 if (chunks.size() == 1) 818 return *chunks.front(); 819 820 // Search for the first chunk with a greater line offset, the previous chunk 821 // is the one that contains `pos`. 822 auto it = llvm::upper_bound( 823 chunks, pos, [](const lsp::Position &pos, const auto &chunk) { 824 return static_cast<uint64_t>(pos.line) < chunk->lineOffset; 825 }); 826 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it); 827 pos.line -= chunk.lineOffset; 828 return chunk; 829 } 830 831 //===----------------------------------------------------------------------===// 832 // MLIRServer::Impl 833 //===----------------------------------------------------------------------===// 834 835 struct lsp::MLIRServer::Impl { 836 Impl(DialectRegistry ®istry) : registry(registry) {} 837 838 /// The registry containing dialects that can be recognized in parsed .mlir 839 /// files. 840 DialectRegistry ®istry; 841 842 /// The files held by the server, mapped by their URI file name. 843 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files; 844 }; 845 846 //===----------------------------------------------------------------------===// 847 // MLIRServer 848 //===----------------------------------------------------------------------===// 849 850 lsp::MLIRServer::MLIRServer(DialectRegistry ®istry) 851 : impl(std::make_unique<Impl>(registry)) {} 852 lsp::MLIRServer::~MLIRServer() = default; 853 854 void lsp::MLIRServer::addOrUpdateDocument( 855 const URIForFile &uri, StringRef contents, int64_t version, 856 std::vector<Diagnostic> &diagnostics) { 857 impl->files[uri.file()] = std::make_unique<MLIRTextFile>( 858 uri, contents, version, impl->registry, diagnostics); 859 } 860 861 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) { 862 auto it = impl->files.find(uri.file()); 863 if (it == impl->files.end()) 864 return llvm::None; 865 866 int64_t version = it->second->getVersion(); 867 impl->files.erase(it); 868 return version; 869 } 870 871 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, 872 const Position &defPos, 873 std::vector<Location> &locations) { 874 auto fileIt = impl->files.find(uri.file()); 875 if (fileIt != impl->files.end()) 876 fileIt->second->getLocationsOf(uri, defPos, locations); 877 } 878 879 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, 880 const Position &pos, 881 std::vector<Location> &references) { 882 auto fileIt = impl->files.find(uri.file()); 883 if (fileIt != impl->files.end()) 884 fileIt->second->findReferencesOf(uri, pos, references); 885 } 886 887 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri, 888 const Position &hoverPos) { 889 auto fileIt = impl->files.find(uri.file()); 890 if (fileIt != impl->files.end()) 891 return fileIt->second->findHover(uri, hoverPos); 892 return llvm::None; 893 } 894 895 void lsp::MLIRServer::findDocumentSymbols( 896 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) { 897 auto fileIt = impl->files.find(uri.file()); 898 if (fileIt != impl->files.end()) 899 fileIt->second->findDocumentSymbols(symbols); 900 } 901