1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MLIRServer.h" 10 #include "../lsp-server-support/Logging.h" 11 #include "../lsp-server-support/Protocol.h" 12 #include "../lsp-server-support/SourceMgrUtils.h" 13 #include "mlir/IR/FunctionInterfaces.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/Parser/AsmParserState.h" 16 #include "mlir/Parser/CodeComplete.h" 17 #include "mlir/Parser/Parser.h" 18 #include "llvm/Support/SourceMgr.h" 19 20 using namespace mlir; 21 22 /// Returns a language server location from the given MLIR file location. 23 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) { 24 llvm::Expected<lsp::URIForFile> sourceURI = 25 lsp::URIForFile::fromFile(loc.getFilename()); 26 if (!sourceURI) { 27 lsp::Logger::error("Failed to create URI for file `{0}`: {1}", 28 loc.getFilename(), 29 llvm::toString(sourceURI.takeError())); 30 return llvm::None; 31 } 32 33 lsp::Position position; 34 position.line = loc.getLine() - 1; 35 position.character = loc.getColumn() ? loc.getColumn() - 1 : 0; 36 return lsp::Location{*sourceURI, lsp::Range(position)}; 37 } 38 39 /// Returns a language server location from the given MLIR location, or None if 40 /// one couldn't be created. `uri` is an optional additional filter that, when 41 /// present, is used to filter sub locations that do not share the same uri. 42 static Optional<lsp::Location> 43 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc, 44 const lsp::URIForFile *uri = nullptr) { 45 Optional<lsp::Location> location; 46 loc->walk([&](Location nestedLoc) { 47 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 48 if (!fileLoc) 49 return WalkResult::advance(); 50 51 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 52 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) { 53 location = *sourceLoc; 54 SMLoc loc = sourceMgr.FindLocForLineAndColumn( 55 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn()); 56 57 // Use range of potential identifier starting at location, else length 1 58 // range. 59 location->range.end.character += 1; 60 if (Optional<SMRange> range = lsp::convertTokenLocToRange(loc)) { 61 auto lineCol = sourceMgr.getLineAndColumn(range->End); 62 location->range.end.character = 63 std::max(fileLoc.getColumn() + 1, lineCol.second - 1); 64 } 65 return WalkResult::interrupt(); 66 } 67 return WalkResult::advance(); 68 }); 69 return location; 70 } 71 72 /// Collect all of the locations from the given MLIR location that are not 73 /// contained within the given URI. 74 static void collectLocationsFromLoc(Location loc, 75 std::vector<lsp::Location> &locations, 76 const lsp::URIForFile &uri) { 77 SetVector<Location> visitedLocs; 78 loc->walk([&](Location nestedLoc) { 79 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>(); 80 if (!fileLoc || !visitedLocs.insert(nestedLoc)) 81 return WalkResult::advance(); 82 83 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc); 84 if (sourceLoc && sourceLoc->uri != uri) 85 locations.push_back(*sourceLoc); 86 return WalkResult::advance(); 87 }); 88 } 89 90 /// Returns true if the given range contains the given source location. Note 91 /// that this has slightly different behavior than SMRange because it is 92 /// inclusive of the end location. 93 static bool contains(SMRange range, SMLoc loc) { 94 return range.Start.getPointer() <= loc.getPointer() && 95 loc.getPointer() <= range.End.getPointer(); 96 } 97 98 /// Returns true if the given location is contained by the definition or one of 99 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to 100 /// the range within `def` that the provided `loc` overlapped with. 101 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, 102 SMRange *overlappedRange = nullptr) { 103 // Check the main definition. 104 if (contains(def.loc, loc)) { 105 if (overlappedRange) 106 *overlappedRange = def.loc; 107 return true; 108 } 109 110 // Check the uses. 111 const auto *useIt = llvm::find_if( 112 def.uses, [&](const SMRange &range) { return contains(range, loc); }); 113 if (useIt != def.uses.end()) { 114 if (overlappedRange) 115 *overlappedRange = *useIt; 116 return true; 117 } 118 return false; 119 } 120 121 /// Given a location pointing to a result, return the result number it refers 122 /// to or None if it refers to all of the results. 123 static Optional<unsigned> getResultNumberFromLoc(SMLoc loc) { 124 // Skip all of the identifier characters. 125 auto isIdentifierChar = [](char c) { 126 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' || 127 c == '-'; 128 }; 129 const char *curPtr = loc.getPointer(); 130 while (isIdentifierChar(*curPtr)) 131 ++curPtr; 132 133 // Check to see if this location indexes into the result group, via `#`. If it 134 // doesn't, we can't extract a sub result number. 135 if (*curPtr != '#') 136 return llvm::None; 137 138 // Compute the sub result number from the remaining portion of the string. 139 const char *numberStart = ++curPtr; 140 while (llvm::isDigit(*curPtr)) 141 ++curPtr; 142 StringRef numberStr(numberStart, curPtr - numberStart); 143 unsigned resultNumber = 0; 144 return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>() 145 : resultNumber; 146 } 147 148 /// Given a source location range, return the text covered by the given range. 149 /// If the range is invalid, returns None. 150 static Optional<StringRef> getTextFromRange(SMRange range) { 151 if (!range.isValid()) 152 return None; 153 const char *startPtr = range.Start.getPointer(); 154 return StringRef(startPtr, range.End.getPointer() - startPtr); 155 } 156 157 /// Given a block, return its position in its parent region. 158 static unsigned getBlockNumber(Block *block) { 159 return std::distance(block->getParent()->begin(), block->getIterator()); 160 } 161 162 /// Given a block and source location, print the source name of the block to the 163 /// given output stream. 164 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) { 165 // Try to extract a name from the source location. 166 Optional<StringRef> text = getTextFromRange(loc); 167 if (text && text->startswith("^")) { 168 os << *text; 169 return; 170 } 171 172 // Otherwise, we don't have a name so print the block number. 173 os << "<Block #" << getBlockNumber(block) << ">"; 174 } 175 static void printDefBlockName(raw_ostream &os, 176 const AsmParserState::BlockDefinition &def) { 177 printDefBlockName(os, def.block, def.definition.loc); 178 } 179 180 /// Convert the given MLIR diagnostic to the LSP form. 181 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, 182 Diagnostic &diag, 183 const lsp::URIForFile &uri) { 184 lsp::Diagnostic lspDiag; 185 lspDiag.source = "mlir"; 186 187 // Note: Right now all of the diagnostics are treated as parser issues, but 188 // some are parser and some are verifier. 189 lspDiag.category = "Parse Error"; 190 191 // Try to grab a file location for this diagnostic. 192 // TODO: For simplicity, we just grab the first one. It may be likely that we 193 // will need a more interesting heuristic here.' 194 Optional<lsp::Location> lspLocation = 195 getLocationFromLoc(sourceMgr, diag.getLocation(), &uri); 196 if (lspLocation) 197 lspDiag.range = lspLocation->range; 198 199 // Convert the severity for the diagnostic. 200 switch (diag.getSeverity()) { 201 case DiagnosticSeverity::Note: 202 llvm_unreachable("expected notes to be handled separately"); 203 case DiagnosticSeverity::Warning: 204 lspDiag.severity = lsp::DiagnosticSeverity::Warning; 205 break; 206 case DiagnosticSeverity::Error: 207 lspDiag.severity = lsp::DiagnosticSeverity::Error; 208 break; 209 case DiagnosticSeverity::Remark: 210 lspDiag.severity = lsp::DiagnosticSeverity::Information; 211 break; 212 } 213 lspDiag.message = diag.str(); 214 215 // Attach any notes to the main diagnostic as related information. 216 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags; 217 for (Diagnostic ¬e : diag.getNotes()) { 218 lsp::Location noteLoc; 219 if (Optional<lsp::Location> loc = 220 getLocationFromLoc(sourceMgr, note.getLocation())) 221 noteLoc = *loc; 222 else 223 noteLoc.uri = uri; 224 relatedDiags.emplace_back(noteLoc, note.str()); 225 } 226 if (!relatedDiags.empty()) 227 lspDiag.relatedInformation = std::move(relatedDiags); 228 229 return lspDiag; 230 } 231 232 //===----------------------------------------------------------------------===// 233 // MLIRDocument 234 //===----------------------------------------------------------------------===// 235 236 namespace { 237 /// This class represents all of the information pertaining to a specific MLIR 238 /// document. 239 struct MLIRDocument { 240 MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, 241 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics); 242 MLIRDocument(const MLIRDocument &) = delete; 243 MLIRDocument &operator=(const MLIRDocument &) = delete; 244 245 //===--------------------------------------------------------------------===// 246 // Definitions and References 247 //===--------------------------------------------------------------------===// 248 249 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, 250 std::vector<lsp::Location> &locations); 251 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, 252 std::vector<lsp::Location> &references); 253 254 //===--------------------------------------------------------------------===// 255 // Hover 256 //===--------------------------------------------------------------------===// 257 258 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 259 const lsp::Position &hoverPos); 260 Optional<lsp::Hover> 261 buildHoverForOperation(SMRange hoverRange, 262 const AsmParserState::OperationDefinition &op); 263 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op, 264 unsigned resultStart, 265 unsigned resultEnd, SMLoc posLoc); 266 lsp::Hover buildHoverForBlock(SMRange hoverRange, 267 const AsmParserState::BlockDefinition &block); 268 lsp::Hover 269 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg, 270 const AsmParserState::BlockDefinition &block); 271 272 //===--------------------------------------------------------------------===// 273 // Document Symbols 274 //===--------------------------------------------------------------------===// 275 276 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 277 void findDocumentSymbols(Operation *op, 278 std::vector<lsp::DocumentSymbol> &symbols); 279 280 //===--------------------------------------------------------------------===// 281 // Code Completion 282 //===--------------------------------------------------------------------===// 283 284 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, 285 const lsp::Position &completePos, 286 const DialectRegistry ®istry); 287 288 //===--------------------------------------------------------------------===// 289 // Fields 290 //===--------------------------------------------------------------------===// 291 292 /// The high level parser state used to find definitions and references within 293 /// the source file. 294 AsmParserState asmState; 295 296 /// The container for the IR parsed from the input file. 297 Block parsedIR; 298 299 /// The source manager containing the contents of the input file. 300 llvm::SourceMgr sourceMgr; 301 }; 302 } // namespace 303 304 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, 305 StringRef contents, 306 std::vector<lsp::Diagnostic> &diagnostics) { 307 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) { 308 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri)); 309 }); 310 311 // Try to parsed the given IR string. 312 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); 313 if (!memBuffer) { 314 lsp::Logger::error("Failed to create memory buffer for file", uri.file()); 315 return; 316 } 317 318 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 319 if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr, 320 &asmState))) { 321 // If parsing failed, clear out any of the current state. 322 parsedIR.clear(); 323 asmState = AsmParserState(); 324 return; 325 } 326 } 327 328 //===----------------------------------------------------------------------===// 329 // MLIRDocument: Definitions and References 330 //===----------------------------------------------------------------------===// 331 332 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri, 333 const lsp::Position &defPos, 334 std::vector<lsp::Location> &locations) { 335 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); 336 337 // Functor used to check if an SM definition contains the position. 338 auto containsPosition = [&](const AsmParserState::SMDefinition &def) { 339 if (!isDefOrUse(def, posLoc)) 340 return false; 341 locations.emplace_back(uri, sourceMgr, def.loc); 342 return true; 343 }; 344 345 // Check all definitions related to operations. 346 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 347 if (contains(op.loc, posLoc)) 348 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 349 for (const auto &result : op.resultGroups) 350 if (containsPosition(result.definition)) 351 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 352 for (const auto &symUse : op.symbolUses) { 353 if (contains(symUse, posLoc)) { 354 locations.emplace_back(uri, sourceMgr, op.loc); 355 return collectLocationsFromLoc(op.op->getLoc(), locations, uri); 356 } 357 } 358 } 359 360 // Check all definitions related to blocks. 361 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 362 if (containsPosition(block.definition)) 363 return; 364 for (const AsmParserState::SMDefinition &arg : block.arguments) 365 if (containsPosition(arg)) 366 return; 367 } 368 } 369 370 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri, 371 const lsp::Position &pos, 372 std::vector<lsp::Location> &references) { 373 // Functor used to append all of the definitions/uses of the given SM 374 // definition to the reference list. 375 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) { 376 references.emplace_back(uri, sourceMgr, def.loc); 377 for (const SMRange &use : def.uses) 378 references.emplace_back(uri, sourceMgr, use); 379 }; 380 381 SMLoc posLoc = pos.getAsSMLoc(sourceMgr); 382 383 // Check all definitions related to operations. 384 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 385 if (contains(op.loc, posLoc)) { 386 for (const auto &result : op.resultGroups) 387 appendSMDef(result.definition); 388 for (const auto &symUse : op.symbolUses) 389 if (contains(symUse, posLoc)) 390 references.emplace_back(uri, sourceMgr, symUse); 391 return; 392 } 393 for (const auto &result : op.resultGroups) 394 if (isDefOrUse(result.definition, posLoc)) 395 return appendSMDef(result.definition); 396 for (const auto &symUse : op.symbolUses) { 397 if (!contains(symUse, posLoc)) 398 continue; 399 for (const auto &symUse : op.symbolUses) 400 references.emplace_back(uri, sourceMgr, symUse); 401 return; 402 } 403 } 404 405 // Check all definitions related to blocks. 406 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 407 if (isDefOrUse(block.definition, posLoc)) 408 return appendSMDef(block.definition); 409 410 for (const AsmParserState::SMDefinition &arg : block.arguments) 411 if (isDefOrUse(arg, posLoc)) 412 return appendSMDef(arg); 413 } 414 } 415 416 //===----------------------------------------------------------------------===// 417 // MLIRDocument: Hover 418 //===----------------------------------------------------------------------===// 419 420 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri, 421 const lsp::Position &hoverPos) { 422 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr); 423 SMRange hoverRange; 424 425 // Check for Hovers on operations and results. 426 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { 427 // Check if the position points at this operation. 428 if (contains(op.loc, posLoc)) 429 return buildHoverForOperation(op.loc, op); 430 431 // Check if the position points at the symbol name. 432 for (auto &use : op.symbolUses) 433 if (contains(use, posLoc)) 434 return buildHoverForOperation(use, op); 435 436 // Check if the position points at a result group. 437 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) { 438 const auto &result = op.resultGroups[i]; 439 if (!isDefOrUse(result.definition, posLoc, &hoverRange)) 440 continue; 441 442 // Get the range of results covered by the over position. 443 unsigned resultStart = result.startIndex; 444 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults() 445 : op.resultGroups[i + 1].startIndex; 446 return buildHoverForOperationResult(hoverRange, op.op, resultStart, 447 resultEnd, posLoc); 448 } 449 } 450 451 // Check to see if the hover is over a block argument. 452 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { 453 if (isDefOrUse(block.definition, posLoc, &hoverRange)) 454 return buildHoverForBlock(hoverRange, block); 455 456 for (const auto &arg : llvm::enumerate(block.arguments)) { 457 if (!isDefOrUse(arg.value(), posLoc, &hoverRange)) 458 continue; 459 460 return buildHoverForBlockArgument( 461 hoverRange, block.block->getArgument(arg.index()), block); 462 } 463 } 464 return llvm::None; 465 } 466 467 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation( 468 SMRange hoverRange, const AsmParserState::OperationDefinition &op) { 469 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 470 llvm::raw_string_ostream os(hover.contents.value); 471 472 // Add the operation name to the hover. 473 os << "\"" << op.op->getName() << "\""; 474 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op)) 475 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << ""; 476 os << "\n\n"; 477 478 os << "Generic Form:\n\n```mlir\n"; 479 480 // Temporary drop the regions of this operation so that they don't get 481 // printed in the output. This helps keeps the size of the output hover 482 // small. 483 SmallVector<std::unique_ptr<Region>> regions; 484 for (Region ®ion : op.op->getRegions()) { 485 regions.emplace_back(std::make_unique<Region>()); 486 regions.back()->takeBody(region); 487 } 488 489 op.op->print( 490 os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); 491 os << "\n```\n"; 492 493 // Move the regions back to the current operation. 494 for (Region ®ion : op.op->getRegions()) 495 region.takeBody(*regions.back()); 496 497 return hover; 498 } 499 500 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange, 501 Operation *op, 502 unsigned resultStart, 503 unsigned resultEnd, 504 SMLoc posLoc) { 505 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 506 llvm::raw_string_ostream os(hover.contents.value); 507 508 // Add the parent operation name to the hover. 509 os << "Operation: \"" << op->getName() << "\"\n\n"; 510 511 // Check to see if the location points to a specific result within the 512 // group. 513 if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) { 514 if ((resultStart + *resultNumber) < resultEnd) { 515 resultStart += *resultNumber; 516 resultEnd = resultStart + 1; 517 } 518 } 519 520 // Add the range of results and their types to the hover info. 521 if ((resultStart + 1) == resultEnd) { 522 os << "Result #" << resultStart << "\n\n" 523 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n"; 524 } else { 525 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n" 526 << "Types: "; 527 llvm::interleaveComma( 528 op->getResults().slice(resultStart, resultEnd), os, 529 [&](Value result) { os << "`" << result.getType() << "`"; }); 530 } 531 532 return hover; 533 } 534 535 lsp::Hover 536 MLIRDocument::buildHoverForBlock(SMRange hoverRange, 537 const AsmParserState::BlockDefinition &block) { 538 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 539 llvm::raw_string_ostream os(hover.contents.value); 540 541 // Print the given block to the hover output stream. 542 auto printBlockToHover = [&](Block *newBlock) { 543 if (const auto *def = asmState.getBlockDef(newBlock)) 544 printDefBlockName(os, *def); 545 else 546 printDefBlockName(os, newBlock); 547 }; 548 549 // Display the parent operation, block number, predecessors, and successors. 550 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 551 << "Block #" << getBlockNumber(block.block) << "\n\n"; 552 if (!block.block->hasNoPredecessors()) { 553 os << "Predecessors: "; 554 llvm::interleaveComma(block.block->getPredecessors(), os, 555 printBlockToHover); 556 os << "\n\n"; 557 } 558 if (!block.block->hasNoSuccessors()) { 559 os << "Successors: "; 560 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover); 561 os << "\n\n"; 562 } 563 564 return hover; 565 } 566 567 lsp::Hover MLIRDocument::buildHoverForBlockArgument( 568 SMRange hoverRange, BlockArgument arg, 569 const AsmParserState::BlockDefinition &block) { 570 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 571 llvm::raw_string_ostream os(hover.contents.value); 572 573 // Display the parent operation, block, the argument number, and the type. 574 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" 575 << "Block: "; 576 printDefBlockName(os, block); 577 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n" 578 << "Type: `" << arg.getType() << "`\n\n"; 579 580 return hover; 581 } 582 583 //===----------------------------------------------------------------------===// 584 // MLIRDocument: Document Symbols 585 //===----------------------------------------------------------------------===// 586 587 void MLIRDocument::findDocumentSymbols( 588 std::vector<lsp::DocumentSymbol> &symbols) { 589 for (Operation &op : parsedIR) 590 findDocumentSymbols(&op, symbols); 591 } 592 593 void MLIRDocument::findDocumentSymbols( 594 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) { 595 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols; 596 597 // Check for the source information of this operation. 598 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) { 599 // If this operation defines a symbol, record it. 600 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) { 601 symbols.emplace_back(symbol.getName(), 602 isa<FunctionOpInterface>(op) 603 ? lsp::SymbolKind::Function 604 : lsp::SymbolKind::Class, 605 lsp::Range(sourceMgr, def->scopeLoc), 606 lsp::Range(sourceMgr, def->loc)); 607 childSymbols = &symbols.back().children; 608 609 } else if (op->hasTrait<OpTrait::SymbolTable>()) { 610 // Otherwise, if this is a symbol table push an anonymous document symbol. 611 symbols.emplace_back("<" + op->getName().getStringRef() + ">", 612 lsp::SymbolKind::Namespace, 613 lsp::Range(sourceMgr, def->scopeLoc), 614 lsp::Range(sourceMgr, def->loc)); 615 childSymbols = &symbols.back().children; 616 } 617 } 618 619 // Recurse into the regions of this operation. 620 if (!op->getNumRegions()) 621 return; 622 for (Region ®ion : op->getRegions()) 623 for (Operation &childOp : region.getOps()) 624 findDocumentSymbols(&childOp, *childSymbols); 625 } 626 627 //===----------------------------------------------------------------------===// 628 // MLIRDocument: Code Completion 629 //===----------------------------------------------------------------------===// 630 631 namespace { 632 class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { 633 public: 634 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList, 635 MLIRContext *ctx) 636 : AsmParserCodeCompleteContext(completeLoc), 637 completionList(completionList), ctx(ctx) {} 638 639 /// Signal code completion for a dialect name, with an optional prefix. 640 void completeDialectName(StringRef prefix) final { 641 for (StringRef dialect : ctx->getAvailableDialects()) { 642 lsp::CompletionItem item(prefix + dialect, 643 lsp::CompletionItemKind::Module, 644 /*sortText=*/"3"); 645 item.detail = "dialect"; 646 completionList.items.emplace_back(item); 647 } 648 } 649 using AsmParserCodeCompleteContext::completeDialectName; 650 651 /// Signal code completion for an operation name within the given dialect. 652 void completeOperationName(StringRef dialectName) final { 653 Dialect *dialect = ctx->getOrLoadDialect(dialectName); 654 if (!dialect) 655 return; 656 657 for (const auto &op : ctx->getRegisteredOperations()) { 658 if (&op.getDialect() != dialect) 659 continue; 660 661 lsp::CompletionItem item( 662 op.getStringRef().drop_front(dialectName.size() + 1), 663 lsp::CompletionItemKind::Field, 664 /*sortText=*/"1"); 665 item.detail = "operation"; 666 completionList.items.emplace_back(item); 667 } 668 } 669 670 /// Append the given SSA value as a code completion result for SSA value 671 /// completions. 672 void appendSSAValueCompletion(StringRef name, std::string typeData) final { 673 // Check if we need to insert the `%` or not. 674 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%'; 675 676 lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable); 677 if (stripPrefix) 678 item.insertText = name.drop_front(1).str(); 679 item.detail = std::move(typeData); 680 completionList.items.emplace_back(item); 681 } 682 683 /// Append the given block as a code completion result for block name 684 /// completions. 685 void appendBlockCompletion(StringRef name) final { 686 // Check if we need to insert the `^` or not. 687 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^'; 688 689 lsp::CompletionItem item(name, lsp::CompletionItemKind::Field); 690 if (stripPrefix) 691 item.insertText = name.drop_front(1).str(); 692 completionList.items.emplace_back(item); 693 } 694 695 /// Signal a completion for the given expected token. 696 void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final { 697 for (StringRef token : tokens) { 698 lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword, 699 /*sortText=*/"0"); 700 item.detail = optional ? "optional" : ""; 701 completionList.items.emplace_back(item); 702 } 703 } 704 705 /// Signal a completion for an attribute. 706 void completeAttribute(const llvm::StringMap<Attribute> &aliases) override { 707 appendSimpleCompletions({"affine_set", "affine_map", "dense", "false", 708 "loc", "opaque", "sparse", "true", "unit"}, 709 lsp::CompletionItemKind::Field, 710 /*sortText=*/"1"); 711 712 completeDialectName("#"); 713 completeAliases(aliases, "#"); 714 } 715 void completeDialectAttributeOrAlias( 716 const llvm::StringMap<Attribute> &aliases) override { 717 completeDialectName(); 718 completeAliases(aliases); 719 } 720 721 /// Signal a completion for a type. 722 void completeType(const llvm::StringMap<Type> &aliases) override { 723 // Handle the various builtin types. 724 appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector", 725 "bf16", "f16", "f32", "f64", "f80", "f128", 726 "index", "none"}, 727 lsp::CompletionItemKind::Field, 728 /*sortText=*/"1"); 729 730 // Handle the builtin integer types. 731 for (StringRef type : {"i", "si", "ui"}) { 732 lsp::CompletionItem item(type + "<N>", lsp::CompletionItemKind::Field, 733 /*sortText=*/"1"); 734 item.insertText = type.str(); 735 completionList.items.emplace_back(item); 736 } 737 738 // Insert completions for dialect types and aliases. 739 completeDialectName("!"); 740 completeAliases(aliases, "!"); 741 } 742 void 743 completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override { 744 completeDialectName(); 745 completeAliases(aliases); 746 } 747 748 /// Add completion results for the given set of aliases. 749 template <typename T> 750 void completeAliases(const llvm::StringMap<T> &aliases, 751 StringRef prefix = "") { 752 for (const auto &alias : aliases) { 753 lsp::CompletionItem item(prefix + alias.getKey(), 754 lsp::CompletionItemKind::Field, 755 /*sortText=*/"2"); 756 llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue(); 757 completionList.items.emplace_back(item); 758 } 759 } 760 761 /// Add a set of simple completions that all have the same kind. 762 void appendSimpleCompletions(ArrayRef<StringRef> completions, 763 lsp::CompletionItemKind kind, 764 StringRef sortText = "") { 765 for (StringRef completion : completions) 766 completionList.items.emplace_back(completion, kind, sortText); 767 } 768 769 private: 770 lsp::CompletionList &completionList; 771 MLIRContext *ctx; 772 }; 773 } // namespace 774 775 lsp::CompletionList 776 MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri, 777 const lsp::Position &completePos, 778 const DialectRegistry ®istry) { 779 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr); 780 if (!posLoc.isValid()) 781 return lsp::CompletionList(); 782 783 // To perform code completion, we run another parse of the module with the 784 // code completion context provided. 785 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED); 786 tmpContext.allowUnregisteredDialects(); 787 lsp::CompletionList completionList; 788 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList, 789 &tmpContext); 790 791 Block tmpIR; 792 AsmParserState tmpState; 793 (void)parseSourceFile(sourceMgr, &tmpIR, &tmpContext, 794 /*sourceFileLoc=*/nullptr, &tmpState, 795 &lspCompleteContext); 796 return completionList; 797 } 798 799 //===----------------------------------------------------------------------===// 800 // MLIRTextFileChunk 801 //===----------------------------------------------------------------------===// 802 803 namespace { 804 /// This class represents a single chunk of an MLIR text file. 805 struct MLIRTextFileChunk { 806 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset, 807 const lsp::URIForFile &uri, StringRef contents, 808 std::vector<lsp::Diagnostic> &diagnostics) 809 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {} 810 811 /// Adjust the line number of the given range to anchor at the beginning of 812 /// the file, instead of the beginning of this chunk. 813 void adjustLocForChunkOffset(lsp::Range &range) { 814 adjustLocForChunkOffset(range.start); 815 adjustLocForChunkOffset(range.end); 816 } 817 /// Adjust the line number of the given position to anchor at the beginning of 818 /// the file, instead of the beginning of this chunk. 819 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } 820 821 /// The line offset of this chunk from the beginning of the file. 822 uint64_t lineOffset; 823 /// The document referred to by this chunk. 824 MLIRDocument document; 825 }; 826 } // namespace 827 828 //===----------------------------------------------------------------------===// 829 // MLIRTextFile 830 //===----------------------------------------------------------------------===// 831 832 namespace { 833 /// This class represents a text file containing one or more MLIR documents. 834 class MLIRTextFile { 835 public: 836 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 837 int64_t version, DialectRegistry ®istry, 838 std::vector<lsp::Diagnostic> &diagnostics); 839 840 /// Return the current version of this text file. 841 int64_t getVersion() const { return version; } 842 843 //===--------------------------------------------------------------------===// 844 // LSP Queries 845 //===--------------------------------------------------------------------===// 846 847 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos, 848 std::vector<lsp::Location> &locations); 849 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos, 850 std::vector<lsp::Location> &references); 851 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 852 lsp::Position hoverPos); 853 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 854 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, 855 lsp::Position completePos); 856 857 private: 858 /// Find the MLIR document that contains the given position, and update the 859 /// position to be anchored at the start of the found chunk instead of the 860 /// beginning of the file. 861 MLIRTextFileChunk &getChunkFor(lsp::Position &pos); 862 863 /// The context used to hold the state contained by the parsed document. 864 MLIRContext context; 865 866 /// The full string contents of the file. 867 std::string contents; 868 869 /// The version of this file. 870 int64_t version; 871 872 /// The number of lines in the file. 873 int64_t totalNumLines = 0; 874 875 /// The chunks of this file. The order of these chunks is the order in which 876 /// they appear in the text file. 877 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks; 878 }; 879 } // namespace 880 881 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, 882 int64_t version, DialectRegistry ®istry, 883 std::vector<lsp::Diagnostic> &diagnostics) 884 : context(registry, MLIRContext::Threading::DISABLED), 885 contents(fileContents.str()), version(version) { 886 context.allowUnregisteredDialects(); 887 888 // Split the file into separate MLIR documents. 889 // TODO: Find a way to share the split file marker with other tools. We don't 890 // want to use `splitAndProcessBuffer` here, but we do want to make sure this 891 // marker doesn't go out of sync. 892 SmallVector<StringRef, 8> subContents; 893 StringRef(contents).split(subContents, "// -----"); 894 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>( 895 context, /*lineOffset=*/0, uri, subContents.front(), diagnostics)); 896 897 uint64_t lineOffset = subContents.front().count('\n'); 898 for (StringRef docContents : llvm::drop_begin(subContents)) { 899 unsigned currentNumDiags = diagnostics.size(); 900 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri, 901 docContents, diagnostics); 902 lineOffset += docContents.count('\n'); 903 904 // Adjust locations used in diagnostics to account for the offset from the 905 // beginning of the file. 906 for (lsp::Diagnostic &diag : 907 llvm::drop_begin(diagnostics, currentNumDiags)) { 908 chunk->adjustLocForChunkOffset(diag.range); 909 910 if (!diag.relatedInformation) 911 continue; 912 for (auto &it : *diag.relatedInformation) 913 if (it.location.uri == uri) 914 chunk->adjustLocForChunkOffset(it.location.range); 915 } 916 chunks.emplace_back(std::move(chunk)); 917 } 918 totalNumLines = lineOffset; 919 } 920 921 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri, 922 lsp::Position defPos, 923 std::vector<lsp::Location> &locations) { 924 MLIRTextFileChunk &chunk = getChunkFor(defPos); 925 chunk.document.getLocationsOf(uri, defPos, locations); 926 927 // Adjust any locations within this file for the offset of this chunk. 928 if (chunk.lineOffset == 0) 929 return; 930 for (lsp::Location &loc : locations) 931 if (loc.uri == uri) 932 chunk.adjustLocForChunkOffset(loc.range); 933 } 934 935 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri, 936 lsp::Position pos, 937 std::vector<lsp::Location> &references) { 938 MLIRTextFileChunk &chunk = getChunkFor(pos); 939 chunk.document.findReferencesOf(uri, pos, references); 940 941 // Adjust any locations within this file for the offset of this chunk. 942 if (chunk.lineOffset == 0) 943 return; 944 for (lsp::Location &loc : references) 945 if (loc.uri == uri) 946 chunk.adjustLocForChunkOffset(loc.range); 947 } 948 949 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri, 950 lsp::Position hoverPos) { 951 MLIRTextFileChunk &chunk = getChunkFor(hoverPos); 952 Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos); 953 954 // Adjust any locations within this file for the offset of this chunk. 955 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) 956 chunk.adjustLocForChunkOffset(*hoverInfo->range); 957 return hoverInfo; 958 } 959 960 void MLIRTextFile::findDocumentSymbols( 961 std::vector<lsp::DocumentSymbol> &symbols) { 962 if (chunks.size() == 1) 963 return chunks.front()->document.findDocumentSymbols(symbols); 964 965 // If there are multiple chunks in this file, we create top-level symbols for 966 // each chunk. 967 for (unsigned i = 0, e = chunks.size(); i < e; ++i) { 968 MLIRTextFileChunk &chunk = *chunks[i]; 969 lsp::Position startPos(chunk.lineOffset); 970 lsp::Position endPos((i == e - 1) ? totalNumLines - 1 971 : chunks[i + 1]->lineOffset); 972 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">", 973 lsp::SymbolKind::Namespace, 974 /*range=*/lsp::Range(startPos, endPos), 975 /*selectionRange=*/lsp::Range(startPos)); 976 chunk.document.findDocumentSymbols(symbol.children); 977 978 // Fixup the locations of document symbols within this chunk. 979 if (i != 0) { 980 SmallVector<lsp::DocumentSymbol *> symbolsToFix; 981 for (lsp::DocumentSymbol &childSymbol : symbol.children) 982 symbolsToFix.push_back(&childSymbol); 983 984 while (!symbolsToFix.empty()) { 985 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); 986 chunk.adjustLocForChunkOffset(symbol->range); 987 chunk.adjustLocForChunkOffset(symbol->selectionRange); 988 989 for (lsp::DocumentSymbol &childSymbol : symbol->children) 990 symbolsToFix.push_back(&childSymbol); 991 } 992 } 993 994 // Push the symbol for this chunk. 995 symbols.emplace_back(std::move(symbol)); 996 } 997 } 998 999 lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri, 1000 lsp::Position completePos) { 1001 MLIRTextFileChunk &chunk = getChunkFor(completePos); 1002 lsp::CompletionList completionList = chunk.document.getCodeCompletion( 1003 uri, completePos, context.getDialectRegistry()); 1004 1005 // Adjust any completion locations. 1006 for (lsp::CompletionItem &item : completionList.items) { 1007 if (item.textEdit) 1008 chunk.adjustLocForChunkOffset(item.textEdit->range); 1009 for (lsp::TextEdit &edit : item.additionalTextEdits) 1010 chunk.adjustLocForChunkOffset(edit.range); 1011 } 1012 return completionList; 1013 } 1014 1015 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { 1016 if (chunks.size() == 1) 1017 return *chunks.front(); 1018 1019 // Search for the first chunk with a greater line offset, the previous chunk 1020 // is the one that contains `pos`. 1021 auto it = llvm::upper_bound( 1022 chunks, pos, [](const lsp::Position &pos, const auto &chunk) { 1023 return static_cast<uint64_t>(pos.line) < chunk->lineOffset; 1024 }); 1025 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it); 1026 pos.line -= chunk.lineOffset; 1027 return chunk; 1028 } 1029 1030 //===----------------------------------------------------------------------===// 1031 // MLIRServer::Impl 1032 //===----------------------------------------------------------------------===// 1033 1034 struct lsp::MLIRServer::Impl { 1035 Impl(DialectRegistry ®istry) : registry(registry) {} 1036 1037 /// The registry containing dialects that can be recognized in parsed .mlir 1038 /// files. 1039 DialectRegistry ®istry; 1040 1041 /// The files held by the server, mapped by their URI file name. 1042 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files; 1043 }; 1044 1045 //===----------------------------------------------------------------------===// 1046 // MLIRServer 1047 //===----------------------------------------------------------------------===// 1048 1049 lsp::MLIRServer::MLIRServer(DialectRegistry ®istry) 1050 : impl(std::make_unique<Impl>(registry)) {} 1051 lsp::MLIRServer::~MLIRServer() = default; 1052 1053 void lsp::MLIRServer::addOrUpdateDocument( 1054 const URIForFile &uri, StringRef contents, int64_t version, 1055 std::vector<Diagnostic> &diagnostics) { 1056 impl->files[uri.file()] = std::make_unique<MLIRTextFile>( 1057 uri, contents, version, impl->registry, diagnostics); 1058 } 1059 1060 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) { 1061 auto it = impl->files.find(uri.file()); 1062 if (it == impl->files.end()) 1063 return llvm::None; 1064 1065 int64_t version = it->second->getVersion(); 1066 impl->files.erase(it); 1067 return version; 1068 } 1069 1070 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, 1071 const Position &defPos, 1072 std::vector<Location> &locations) { 1073 auto fileIt = impl->files.find(uri.file()); 1074 if (fileIt != impl->files.end()) 1075 fileIt->second->getLocationsOf(uri, defPos, locations); 1076 } 1077 1078 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, 1079 const Position &pos, 1080 std::vector<Location> &references) { 1081 auto fileIt = impl->files.find(uri.file()); 1082 if (fileIt != impl->files.end()) 1083 fileIt->second->findReferencesOf(uri, pos, references); 1084 } 1085 1086 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri, 1087 const Position &hoverPos) { 1088 auto fileIt = impl->files.find(uri.file()); 1089 if (fileIt != impl->files.end()) 1090 return fileIt->second->findHover(uri, hoverPos); 1091 return llvm::None; 1092 } 1093 1094 void lsp::MLIRServer::findDocumentSymbols( 1095 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) { 1096 auto fileIt = impl->files.find(uri.file()); 1097 if (fileIt != impl->files.end()) 1098 fileIt->second->findDocumentSymbols(symbols); 1099 } 1100 1101 lsp::CompletionList 1102 lsp::MLIRServer::getCodeCompletion(const URIForFile &uri, 1103 const Position &completePos) { 1104 auto fileIt = impl->files.find(uri.file()); 1105 if (fileIt != impl->files.end()) 1106 return fileIt->second->getCodeCompletion(uri, completePos); 1107 return CompletionList(); 1108 } 1109