1 //===- PDLLServer.cpp - PDLL Language Server ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PDLLServer.h" 10 11 #include "../lsp-server-support/Logging.h" 12 #include "../lsp-server-support/Protocol.h" 13 #include "CompilationDatabase.h" 14 #include "mlir/Tools/PDLL/AST/Context.h" 15 #include "mlir/Tools/PDLL/AST/Nodes.h" 16 #include "mlir/Tools/PDLL/AST/Types.h" 17 #include "mlir/Tools/PDLL/ODS/Constraint.h" 18 #include "mlir/Tools/PDLL/ODS/Context.h" 19 #include "mlir/Tools/PDLL/ODS/Dialect.h" 20 #include "mlir/Tools/PDLL/ODS/Operation.h" 21 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 22 #include "mlir/Tools/PDLL/Parser/Parser.h" 23 #include "llvm/ADT/IntervalMap.h" 24 #include "llvm/ADT/StringMap.h" 25 #include "llvm/ADT/StringSet.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FileSystem.h" 28 #include "llvm/Support/Path.h" 29 30 using namespace mlir; 31 using namespace mlir::pdll; 32 33 /// Returns a language server uri for the given source location. `mainFileURI` 34 /// corresponds to the uri for the main file of the source manager. 35 static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, 36 const lsp::URIForFile &mainFileURI) { 37 int bufferId = mgr.FindBufferContainingLoc(loc.Start); 38 if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID())) 39 return mainFileURI; 40 llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile( 41 mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); 42 if (fileForLoc) 43 return *fileForLoc; 44 lsp::Logger::error("Failed to create URI for include file: {0}", 45 llvm::toString(fileForLoc.takeError())); 46 return mainFileURI; 47 } 48 49 /// Returns true if the given location is in the main file of the source 50 /// manager. 51 static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) { 52 return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID(); 53 } 54 55 /// Returns a language server location from the given source range. 56 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, 57 const lsp::URIForFile &uri) { 58 return lsp::Location(getURIFromLoc(mgr, range, uri), lsp::Range(mgr, range)); 59 } 60 61 /// Convert the given MLIR diagnostic to the LSP form. 62 static Optional<lsp::Diagnostic> 63 getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, 64 const lsp::URIForFile &uri) { 65 lsp::Diagnostic lspDiag; 66 lspDiag.source = "pdll"; 67 68 // FIXME: Right now all of the diagnostics are treated as parser issues, but 69 // some are parser and some are verifier. 70 lspDiag.category = "Parse Error"; 71 72 // Try to grab a file location for this diagnostic. 73 lsp::Location loc = getLocationFromLoc(sourceMgr, diag.getLocation(), uri); 74 lspDiag.range = loc.range; 75 76 // Skip diagnostics that weren't emitted within the main file. 77 if (loc.uri != uri) 78 return llvm::None; 79 80 // Convert the severity for the diagnostic. 81 switch (diag.getSeverity()) { 82 case ast::Diagnostic::Severity::DK_Note: 83 llvm_unreachable("expected notes to be handled separately"); 84 case ast::Diagnostic::Severity::DK_Warning: 85 lspDiag.severity = lsp::DiagnosticSeverity::Warning; 86 break; 87 case ast::Diagnostic::Severity::DK_Error: 88 lspDiag.severity = lsp::DiagnosticSeverity::Error; 89 break; 90 case ast::Diagnostic::Severity::DK_Remark: 91 lspDiag.severity = lsp::DiagnosticSeverity::Information; 92 break; 93 } 94 lspDiag.message = diag.getMessage().str(); 95 96 // Attach any notes to the main diagnostic as related information. 97 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags; 98 for (const ast::Diagnostic ¬e : diag.getNotes()) { 99 relatedDiags.emplace_back( 100 getLocationFromLoc(sourceMgr, note.getLocation(), uri), 101 note.getMessage().str()); 102 } 103 if (!relatedDiags.empty()) 104 lspDiag.relatedInformation = std::move(relatedDiags); 105 106 return lspDiag; 107 } 108 109 //===----------------------------------------------------------------------===// 110 // PDLLInclude 111 //===----------------------------------------------------------------------===// 112 113 namespace { 114 /// This class represents a single include within a root file. 115 struct PDLLInclude { 116 PDLLInclude(const lsp::URIForFile &uri, const lsp::Range &range) 117 : uri(uri), range(range) {} 118 119 /// The URI of the file that is included. 120 lsp::URIForFile uri; 121 122 /// The range of the include directive. 123 lsp::Range range; 124 }; 125 } // namespace 126 127 //===----------------------------------------------------------------------===// 128 // PDLIndex 129 //===----------------------------------------------------------------------===// 130 131 namespace { 132 struct PDLIndexSymbol { 133 explicit PDLIndexSymbol(const ast::Decl *definition) 134 : definition(definition) {} 135 explicit PDLIndexSymbol(const ods::Operation *definition) 136 : definition(definition) {} 137 138 /// Return the location of the definition of this symbol. 139 SMRange getDefLoc() const { 140 if (const ast::Decl *decl = definition.dyn_cast<const ast::Decl *>()) { 141 const ast::Name *declName = decl->getName(); 142 return declName ? declName->getLoc() : decl->getLoc(); 143 } 144 return definition.get<const ods::Operation *>()->getLoc(); 145 } 146 147 /// The main definition of the symbol. 148 PointerUnion<const ast::Decl *, const ods::Operation *> definition; 149 /// The set of references to the symbol. 150 std::vector<SMRange> references; 151 }; 152 153 /// This class provides an index for definitions/uses within a PDL document. 154 /// It provides efficient lookup of a definition given an input source range. 155 class PDLIndex { 156 public: 157 PDLIndex() : intervalMap(allocator) {} 158 159 /// Initialize the index with the given ast::Module. 160 void initialize(const ast::Module &module, const ods::Context &odsContext); 161 162 /// Lookup a symbol for the given location. Returns nullptr if no symbol could 163 /// be found. If provided, `overlappedRange` is set to the range that the 164 /// provided `loc` overlapped with. 165 const PDLIndexSymbol *lookup(SMLoc loc, 166 SMRange *overlappedRange = nullptr) const; 167 168 private: 169 /// The type of interval map used to store source references. SMRange is 170 /// half-open, so we also need to use a half-open interval map. 171 using MapT = 172 llvm::IntervalMap<const char *, const PDLIndexSymbol *, 173 llvm::IntervalMapImpl::NodeSizer< 174 const char *, const PDLIndexSymbol *>::LeafSize, 175 llvm::IntervalMapHalfOpenInfo<const char *>>; 176 177 /// An allocator for the interval map. 178 MapT::Allocator allocator; 179 180 /// An interval map containing a corresponding definition mapped to a source 181 /// interval. 182 MapT intervalMap; 183 184 /// A mapping between definitions and their corresponding symbol. 185 DenseMap<const void *, std::unique_ptr<PDLIndexSymbol>> defToSymbol; 186 }; 187 } // namespace 188 189 void PDLIndex::initialize(const ast::Module &module, 190 const ods::Context &odsContext) { 191 auto getOrInsertDef = [&](const auto *def) -> PDLIndexSymbol * { 192 auto it = defToSymbol.try_emplace(def, nullptr); 193 if (it.second) 194 it.first->second = std::make_unique<PDLIndexSymbol>(def); 195 return &*it.first->second; 196 }; 197 auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc, 198 bool isDef = false) { 199 const char *startLoc = refLoc.Start.getPointer(); 200 const char *endLoc = refLoc.End.getPointer(); 201 if (!intervalMap.overlaps(startLoc, endLoc)) { 202 intervalMap.insert(startLoc, endLoc, sym); 203 if (!isDef) 204 sym->references.push_back(refLoc); 205 } 206 }; 207 auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) { 208 const ods::Operation *odsOp = odsContext.lookupOperation(opName); 209 if (!odsOp) 210 return; 211 212 PDLIndexSymbol *symbol = getOrInsertDef(odsOp); 213 insertDeclRef(symbol, odsOp->getLoc(), /*isDef=*/true); 214 insertDeclRef(symbol, refLoc); 215 }; 216 217 module.walk([&](const ast::Node *node) { 218 // Handle references to PDL decls. 219 if (const auto *decl = dyn_cast<ast::OpNameDecl>(node)) { 220 if (Optional<StringRef> name = decl->getName()) 221 insertODSOpRef(*name, decl->getLoc()); 222 } else if (const ast::Decl *decl = dyn_cast<ast::Decl>(node)) { 223 const ast::Name *name = decl->getName(); 224 if (!name) 225 return; 226 PDLIndexSymbol *declSym = getOrInsertDef(decl); 227 insertDeclRef(declSym, name->getLoc(), /*isDef=*/true); 228 229 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) { 230 // Record references to any constraints. 231 for (const auto &it : varDecl->getConstraints()) 232 insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc); 233 } 234 } else if (const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) { 235 insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc()); 236 } 237 }); 238 } 239 240 const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc, 241 SMRange *overlappedRange) const { 242 auto it = intervalMap.find(loc.getPointer()); 243 if (!it.valid() || loc.getPointer() < it.start()) 244 return nullptr; 245 246 if (overlappedRange) { 247 *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()), 248 SMLoc::getFromPointer(it.stop())); 249 } 250 return it.value(); 251 } 252 253 //===----------------------------------------------------------------------===// 254 // PDLDocument 255 //===----------------------------------------------------------------------===// 256 257 namespace { 258 /// This class represents all of the information pertaining to a specific PDL 259 /// document. 260 struct PDLDocument { 261 PDLDocument(const lsp::URIForFile &uri, StringRef contents, 262 const std::vector<std::string> &extraDirs, 263 std::vector<lsp::Diagnostic> &diagnostics); 264 PDLDocument(const PDLDocument &) = delete; 265 PDLDocument &operator=(const PDLDocument &) = delete; 266 267 //===--------------------------------------------------------------------===// 268 // Definitions and References 269 //===--------------------------------------------------------------------===// 270 271 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, 272 std::vector<lsp::Location> &locations); 273 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, 274 std::vector<lsp::Location> &references); 275 276 //===--------------------------------------------------------------------===// 277 // Document Links 278 //===--------------------------------------------------------------------===// 279 280 void getDocumentLinks(const lsp::URIForFile &uri, 281 std::vector<lsp::DocumentLink> &links); 282 283 //===--------------------------------------------------------------------===// 284 // Hover 285 //===--------------------------------------------------------------------===// 286 287 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 288 const lsp::Position &hoverPos); 289 Optional<lsp::Hover> findHover(const ast::Decl *decl, 290 const SMRange &hoverRange); 291 lsp::Hover buildHoverForInclude(const PDLLInclude &include); 292 lsp::Hover buildHoverForOpName(const ods::Operation *op, 293 const SMRange &hoverRange); 294 lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl, 295 const SMRange &hoverRange); 296 lsp::Hover buildHoverForPattern(const ast::PatternDecl *patternDecl, 297 const SMRange &hoverRange); 298 lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, 299 const SMRange &hoverRange); 300 template <typename T> 301 lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName, 302 const T *decl, 303 const SMRange &hoverRange); 304 305 //===--------------------------------------------------------------------===// 306 // Document Symbols 307 //===--------------------------------------------------------------------===// 308 309 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 310 311 //===--------------------------------------------------------------------===// 312 // Code Completion 313 //===--------------------------------------------------------------------===// 314 315 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, 316 const lsp::Position &completePos); 317 318 //===--------------------------------------------------------------------===// 319 // Signature Help 320 //===--------------------------------------------------------------------===// 321 322 lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri, 323 const lsp::Position &helpPos); 324 325 //===--------------------------------------------------------------------===// 326 // Fields 327 //===--------------------------------------------------------------------===// 328 329 /// The include directories for this file. 330 std::vector<std::string> includeDirs; 331 332 /// The source manager containing the contents of the input file. 333 llvm::SourceMgr sourceMgr; 334 335 /// The ODS and AST contexts. 336 ods::Context odsContext; 337 ast::Context astContext; 338 339 /// The parsed AST module, or failure if the file wasn't valid. 340 FailureOr<ast::Module *> astModule; 341 342 /// The index of the parsed module. 343 PDLIndex index; 344 345 /// The set of includes of the parsed module. 346 std::vector<PDLLInclude> parsedIncludes; 347 }; 348 } // namespace 349 350 PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents, 351 const std::vector<std::string> &extraDirs, 352 std::vector<lsp::Diagnostic> &diagnostics) 353 : astContext(odsContext) { 354 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); 355 if (!memBuffer) { 356 lsp::Logger::error("Failed to create memory buffer for file", uri.file()); 357 return; 358 } 359 360 // Build the set of include directories for this file. 361 llvm::SmallString<32> uriDirectory(uri.file()); 362 llvm::sys::path::remove_filename(uriDirectory); 363 includeDirs.push_back(uriDirectory.str().str()); 364 includeDirs.insert(includeDirs.end(), extraDirs.begin(), extraDirs.end()); 365 366 sourceMgr.setIncludeDirs(includeDirs); 367 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 368 369 astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) { 370 if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri)) 371 diagnostics.push_back(std::move(*lspDiag)); 372 }); 373 astModule = parsePDLAST(astContext, sourceMgr); 374 375 // Initialize the set of parsed includes. 376 for (unsigned i = 1, e = sourceMgr.getNumBuffers(); i < e; ++i) { 377 // Check to see if this file was included by the main file. 378 SMLoc includeLoc = sourceMgr.getBufferInfo(i + 1).IncludeLoc; 379 if (!includeLoc.isValid() || sourceMgr.FindBufferContainingLoc( 380 includeLoc) != sourceMgr.getMainFileID()) 381 continue; 382 383 // Try to build a URI for this file path. 384 auto *buffer = sourceMgr.getMemoryBuffer(i + 1); 385 llvm::SmallString<256> path(buffer->getBufferIdentifier()); 386 llvm::sys::path::remove_dots(path, /*remove_dot_dot=*/true); 387 388 llvm::Expected<lsp::URIForFile> includedFileURI = 389 lsp::URIForFile::fromFile(path); 390 if (!includedFileURI) 391 continue; 392 393 // Find the end of the include token. 394 const char *includeStart = includeLoc.getPointer() - 2; 395 while (*(--includeStart) != '\"') 396 continue; 397 398 // Push this include. 399 SMRange includeRange(SMLoc::getFromPointer(includeStart), includeLoc); 400 parsedIncludes.emplace_back(*includedFileURI, 401 lsp::Range(sourceMgr, includeRange)); 402 } 403 404 // If we failed to parse the module, there is nothing left to initialize. 405 if (failed(astModule)) 406 return; 407 408 // Prepare the AST index with the parsed module. 409 index.initialize(**astModule, odsContext); 410 } 411 412 //===----------------------------------------------------------------------===// 413 // PDLDocument: Definitions and References 414 //===----------------------------------------------------------------------===// 415 416 void PDLDocument::getLocationsOf(const lsp::URIForFile &uri, 417 const lsp::Position &defPos, 418 std::vector<lsp::Location> &locations) { 419 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); 420 const PDLIndexSymbol *symbol = index.lookup(posLoc); 421 if (!symbol) 422 return; 423 424 locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri)); 425 } 426 427 void PDLDocument::findReferencesOf(const lsp::URIForFile &uri, 428 const lsp::Position &pos, 429 std::vector<lsp::Location> &references) { 430 SMLoc posLoc = pos.getAsSMLoc(sourceMgr); 431 const PDLIndexSymbol *symbol = index.lookup(posLoc); 432 if (!symbol) 433 return; 434 435 references.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri)); 436 for (SMRange refLoc : symbol->references) 437 references.push_back(getLocationFromLoc(sourceMgr, refLoc, uri)); 438 } 439 440 //===--------------------------------------------------------------------===// 441 // PDLDocument: Document Links 442 //===--------------------------------------------------------------------===// 443 444 void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri, 445 std::vector<lsp::DocumentLink> &links) { 446 for (const PDLLInclude &include : parsedIncludes) 447 links.emplace_back(include.range, include.uri); 448 } 449 450 //===----------------------------------------------------------------------===// 451 // PDLDocument: Hover 452 //===----------------------------------------------------------------------===// 453 454 Optional<lsp::Hover> PDLDocument::findHover(const lsp::URIForFile &uri, 455 const lsp::Position &hoverPos) { 456 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr); 457 458 // Check for a reference to an include. 459 for (const PDLLInclude &include : parsedIncludes) { 460 if (include.range.contains(hoverPos)) 461 return buildHoverForInclude(include); 462 } 463 464 // Find the symbol at the given location. 465 SMRange hoverRange; 466 const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange); 467 if (!symbol) 468 return llvm::None; 469 470 // Add hover for operation names. 471 if (const auto *op = symbol->definition.dyn_cast<const ods::Operation *>()) 472 return buildHoverForOpName(op, hoverRange); 473 const auto *decl = symbol->definition.get<const ast::Decl *>(); 474 return findHover(decl, hoverRange); 475 } 476 477 Optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl, 478 const SMRange &hoverRange) { 479 // Add hover for variables. 480 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 481 return buildHoverForVariable(varDecl, hoverRange); 482 483 // Add hover for patterns. 484 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) 485 return buildHoverForPattern(patternDecl, hoverRange); 486 487 // Add hover for core constraints. 488 if (const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl)) 489 return buildHoverForCoreConstraint(cst, hoverRange); 490 491 // Add hover for user constraints. 492 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) 493 return buildHoverForUserConstraintOrRewrite("Constraint", cst, hoverRange); 494 495 // Add hover for user rewrites. 496 if (const auto *rewrite = dyn_cast<ast::UserRewriteDecl>(decl)) 497 return buildHoverForUserConstraintOrRewrite("Rewrite", rewrite, hoverRange); 498 499 return llvm::None; 500 } 501 502 lsp::Hover PDLDocument::buildHoverForInclude(const PDLLInclude &include) { 503 lsp::Hover hover(include.range); 504 { 505 llvm::raw_string_ostream hoverOS(hover.contents.value); 506 hoverOS << "`" << llvm::sys::path::filename(include.uri.file()) 507 << "`\n***\n" 508 << include.uri.file(); 509 } 510 return hover; 511 } 512 513 lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op, 514 const SMRange &hoverRange) { 515 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 516 { 517 llvm::raw_string_ostream hoverOS(hover.contents.value); 518 hoverOS << "**OpName**: `" << op->getName() << "`\n***\n" 519 << op->getSummary() << "\n***\n" 520 << op->getDescription(); 521 } 522 return hover; 523 } 524 525 lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl, 526 const SMRange &hoverRange) { 527 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 528 { 529 llvm::raw_string_ostream hoverOS(hover.contents.value); 530 hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n" 531 << "Type: `" << varDecl->getType() << "`\n"; 532 } 533 return hover; 534 } 535 536 lsp::Hover 537 PDLDocument::buildHoverForPattern(const ast::PatternDecl *patternDecl, 538 const SMRange &hoverRange) { 539 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 540 { 541 llvm::raw_string_ostream hoverOS(hover.contents.value); 542 hoverOS << "**Pattern**"; 543 if (const ast::Name *name = patternDecl->getName()) 544 hoverOS << ": `" << name->getName() << "`"; 545 hoverOS << "\n***\n"; 546 if (Optional<uint16_t> benefit = patternDecl->getBenefit()) 547 hoverOS << "Benefit: " << *benefit << "\n"; 548 if (patternDecl->hasBoundedRewriteRecursion()) 549 hoverOS << "HasBoundedRewriteRecursion\n"; 550 hoverOS << "RootOp: `" 551 << patternDecl->getRootRewriteStmt()->getRootOpExpr()->getType() 552 << "`\n"; 553 } 554 return hover; 555 } 556 557 lsp::Hover 558 PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, 559 const SMRange &hoverRange) { 560 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 561 { 562 llvm::raw_string_ostream hoverOS(hover.contents.value); 563 hoverOS << "**Constraint**: `"; 564 TypeSwitch<const ast::Decl *>(decl) 565 .Case([&](const ast::AttrConstraintDecl *) { hoverOS << "Attr"; }) 566 .Case([&](const ast::OpConstraintDecl *opCst) { 567 hoverOS << "Op"; 568 if (Optional<StringRef> name = opCst->getName()) 569 hoverOS << "<" << name << ">"; 570 }) 571 .Case([&](const ast::TypeConstraintDecl *) { hoverOS << "Type"; }) 572 .Case([&](const ast::TypeRangeConstraintDecl *) { 573 hoverOS << "TypeRange"; 574 }) 575 .Case([&](const ast::ValueConstraintDecl *) { hoverOS << "Value"; }) 576 .Case([&](const ast::ValueRangeConstraintDecl *) { 577 hoverOS << "ValueRange"; 578 }); 579 hoverOS << "`\n"; 580 } 581 return hover; 582 } 583 584 template <typename T> 585 lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( 586 StringRef typeName, const T *decl, const SMRange &hoverRange) { 587 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); 588 { 589 llvm::raw_string_ostream hoverOS(hover.contents.value); 590 hoverOS << "**" << typeName << "**: `" << decl->getName().getName() 591 << "`\n***\n"; 592 ArrayRef<ast::VariableDecl *> inputs = decl->getInputs(); 593 if (!inputs.empty()) { 594 hoverOS << "Parameters:\n"; 595 for (const ast::VariableDecl *input : inputs) 596 hoverOS << "* " << input->getName().getName() << ": `" 597 << input->getType() << "`\n"; 598 hoverOS << "***\n"; 599 } 600 ast::Type resultType = decl->getResultType(); 601 if (auto resultTupleTy = resultType.dyn_cast<ast::TupleType>()) { 602 if (resultTupleTy.empty()) 603 return hover; 604 605 hoverOS << "Results:\n"; 606 for (auto it : llvm::zip(resultTupleTy.getElementNames(), 607 resultTupleTy.getElementTypes())) { 608 StringRef name = std::get<0>(it); 609 hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`" 610 << std::get<1>(it) << "`\n"; 611 } 612 } else { 613 hoverOS << "Results:\n* `" << resultType << "`\n"; 614 } 615 hoverOS << "***\n"; 616 } 617 return hover; 618 } 619 620 //===----------------------------------------------------------------------===// 621 // PDLDocument: Document Symbols 622 //===----------------------------------------------------------------------===// 623 624 void PDLDocument::findDocumentSymbols( 625 std::vector<lsp::DocumentSymbol> &symbols) { 626 if (failed(astModule)) 627 return; 628 629 for (const ast::Decl *decl : (*astModule)->getChildren()) { 630 if (!isMainFileLoc(sourceMgr, decl->getLoc())) 631 continue; 632 633 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) { 634 const ast::Name *name = patternDecl->getName(); 635 636 SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc(); 637 SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End); 638 639 symbols.emplace_back( 640 name ? name->getName() : "<pattern>", lsp::SymbolKind::Class, 641 lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); 642 } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) { 643 // TODO: Add source information for the code block body. 644 SMRange nameLoc = cDecl->getName().getLoc(); 645 SMRange bodyLoc = nameLoc; 646 647 symbols.emplace_back( 648 cDecl->getName().getName(), lsp::SymbolKind::Function, 649 lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); 650 } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) { 651 // TODO: Add source information for the code block body. 652 SMRange nameLoc = cDecl->getName().getLoc(); 653 SMRange bodyLoc = nameLoc; 654 655 symbols.emplace_back( 656 cDecl->getName().getName(), lsp::SymbolKind::Function, 657 lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); 658 } 659 } 660 } 661 662 //===----------------------------------------------------------------------===// 663 // PDLDocument: Code Completion 664 //===----------------------------------------------------------------------===// 665 666 namespace { 667 class LSPCodeCompleteContext : public CodeCompleteContext { 668 public: 669 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList, 670 ods::Context &odsContext, 671 ArrayRef<std::string> includeDirs) 672 : CodeCompleteContext(completeLoc), completionList(completionList), 673 odsContext(odsContext), includeDirs(includeDirs) {} 674 675 void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final { 676 ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes(); 677 ArrayRef<StringRef> elementNames = tupleType.getElementNames(); 678 for (unsigned i = 0, e = tupleType.size(); i < e; ++i) { 679 // Push back a completion item that uses the result index. 680 lsp::CompletionItem item; 681 item.label = llvm::formatv("{0} (field #{0})", i).str(); 682 item.insertText = Twine(i).str(); 683 item.filterText = item.sortText = item.insertText; 684 item.kind = lsp::CompletionItemKind::Field; 685 item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]); 686 item.insertTextFormat = lsp::InsertTextFormat::PlainText; 687 completionList.items.emplace_back(item); 688 689 // If the element has a name, push back a completion item with that name. 690 if (!elementNames[i].empty()) { 691 item.label = 692 llvm::formatv("{1} (field #{0})", i, elementNames[i]).str(); 693 item.filterText = item.label; 694 item.insertText = elementNames[i].str(); 695 completionList.items.emplace_back(item); 696 } 697 } 698 } 699 700 void codeCompleteOperationMemberAccess(ast::OperationType opType) final { 701 Optional<StringRef> opName = opType.getName(); 702 const ods::Operation *odsOp = 703 opName ? odsContext.lookupOperation(*opName) : nullptr; 704 if (!odsOp) 705 return; 706 707 ArrayRef<ods::OperandOrResult> results = odsOp->getResults(); 708 for (const auto &it : llvm::enumerate(results)) { 709 const ods::OperandOrResult &result = it.value(); 710 const ods::TypeConstraint &constraint = result.getConstraint(); 711 712 // Push back a completion item that uses the result index. 713 lsp::CompletionItem item; 714 item.label = llvm::formatv("{0} (field #{0})", it.index()).str(); 715 item.insertText = Twine(it.index()).str(); 716 item.filterText = item.sortText = item.insertText; 717 item.kind = lsp::CompletionItemKind::Field; 718 switch (result.getVariableLengthKind()) { 719 case ods::VariableLengthKind::Single: 720 item.detail = llvm::formatv("{0}: Value", it.index()).str(); 721 break; 722 case ods::VariableLengthKind::Optional: 723 item.detail = llvm::formatv("{0}: Value?", it.index()).str(); 724 break; 725 case ods::VariableLengthKind::Variadic: 726 item.detail = llvm::formatv("{0}: ValueRange", it.index()).str(); 727 break; 728 } 729 item.documentation = lsp::MarkupContent{ 730 lsp::MarkupKind::Markdown, 731 llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(), 732 constraint.getCppClass()) 733 .str()}; 734 item.insertTextFormat = lsp::InsertTextFormat::PlainText; 735 completionList.items.emplace_back(item); 736 737 // If the result has a name, push back a completion item with the result 738 // name. 739 if (!result.getName().empty()) { 740 item.label = 741 llvm::formatv("{1} (field #{0})", it.index(), result.getName()) 742 .str(); 743 item.filterText = item.label; 744 item.insertText = result.getName().str(); 745 completionList.items.emplace_back(item); 746 } 747 } 748 } 749 750 void codeCompleteOperationAttributeName(StringRef opName) final { 751 const ods::Operation *odsOp = odsContext.lookupOperation(opName); 752 if (!odsOp) 753 return; 754 755 for (const ods::Attribute &attr : odsOp->getAttributes()) { 756 const ods::AttributeConstraint &constraint = attr.getConstraint(); 757 758 lsp::CompletionItem item; 759 item.label = attr.getName().str(); 760 item.kind = lsp::CompletionItemKind::Field; 761 item.detail = attr.isOptional() ? "optional" : ""; 762 item.documentation = lsp::MarkupContent{ 763 lsp::MarkupKind::Markdown, 764 llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(), 765 constraint.getCppClass()) 766 .str()}; 767 item.insertTextFormat = lsp::InsertTextFormat::PlainText; 768 completionList.items.emplace_back(item); 769 } 770 } 771 772 void codeCompleteConstraintName(ast::Type currentType, 773 bool allowNonCoreConstraints, 774 bool allowInlineTypeConstraints, 775 const ast::DeclScope *scope) final { 776 auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType, 777 StringRef snippetText = "") { 778 lsp::CompletionItem item; 779 item.label = constraint.str(); 780 item.kind = lsp::CompletionItemKind::Class; 781 item.detail = (constraint + " constraint").str(); 782 item.documentation = lsp::MarkupContent{ 783 lsp::MarkupKind::Markdown, 784 ("A single entity core constraint of type `" + mlirType + "`").str()}; 785 item.sortText = "0"; 786 item.insertText = snippetText.str(); 787 item.insertTextFormat = snippetText.empty() 788 ? lsp::InsertTextFormat::PlainText 789 : lsp::InsertTextFormat::Snippet; 790 completionList.items.emplace_back(item); 791 }; 792 793 // Insert completions for the core constraints. Some core constraints have 794 // additional characteristics, so we may add then even if a type has been 795 // inferred. 796 if (!currentType) { 797 addCoreConstraint("Attr", "mlir::Attribute"); 798 addCoreConstraint("Op", "mlir::Operation *"); 799 addCoreConstraint("Value", "mlir::Value"); 800 addCoreConstraint("ValueRange", "mlir::ValueRange"); 801 addCoreConstraint("Type", "mlir::Type"); 802 addCoreConstraint("TypeRange", "mlir::TypeRange"); 803 } 804 if (allowInlineTypeConstraints) { 805 /// Attr<Type>. 806 if (!currentType || currentType.isa<ast::AttributeType>()) 807 addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>"); 808 /// Value<Type>. 809 if (!currentType || currentType.isa<ast::ValueType>()) 810 addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>"); 811 /// ValueRange<TypeRange>. 812 if (!currentType || currentType.isa<ast::ValueRangeType>()) 813 addCoreConstraint("ValueRange<type>", "mlir::ValueRange", 814 "ValueRange<$1>"); 815 } 816 817 // If a scope was provided, check it for potential constraints. 818 while (scope) { 819 for (const ast::Decl *decl : scope->getDecls()) { 820 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) { 821 if (!allowNonCoreConstraints) 822 continue; 823 824 lsp::CompletionItem item; 825 item.label = cst->getName().getName().str(); 826 item.kind = lsp::CompletionItemKind::Interface; 827 item.sortText = "2_" + item.label; 828 829 // Skip constraints that are not single-arg. We currently only 830 // complete variable constraints. 831 if (cst->getInputs().size() != 1) 832 continue; 833 834 // Ensure the input type matched the given type. 835 ast::Type constraintType = cst->getInputs()[0]->getType(); 836 if (currentType && !currentType.refineWith(constraintType)) 837 continue; 838 839 // Format the constraint signature. 840 { 841 llvm::raw_string_ostream strOS(item.detail); 842 strOS << "("; 843 llvm::interleaveComma( 844 cst->getInputs(), strOS, [&](const ast::VariableDecl *var) { 845 strOS << var->getName().getName() << ": " << var->getType(); 846 }); 847 strOS << ") -> " << cst->getResultType(); 848 } 849 850 completionList.items.emplace_back(item); 851 } 852 } 853 854 scope = scope->getParentScope(); 855 } 856 } 857 858 void codeCompleteDialectName() final { 859 // Code complete known dialects. 860 for (const ods::Dialect &dialect : odsContext.getDialects()) { 861 lsp::CompletionItem item; 862 item.label = dialect.getName().str(); 863 item.kind = lsp::CompletionItemKind::Class; 864 item.insertTextFormat = lsp::InsertTextFormat::PlainText; 865 completionList.items.emplace_back(item); 866 } 867 } 868 869 void codeCompleteOperationName(StringRef dialectName) final { 870 const ods::Dialect *dialect = odsContext.lookupDialect(dialectName); 871 if (!dialect) 872 return; 873 874 for (const auto &it : dialect->getOperations()) { 875 const ods::Operation &op = *it.second; 876 877 lsp::CompletionItem item; 878 item.label = op.getName().drop_front(dialectName.size() + 1).str(); 879 item.kind = lsp::CompletionItemKind::Field; 880 item.insertTextFormat = lsp::InsertTextFormat::PlainText; 881 completionList.items.emplace_back(item); 882 } 883 } 884 885 void codeCompletePatternMetadata() final { 886 auto addSimpleConstraint = [&](StringRef constraint, StringRef desc, 887 StringRef snippetText = "") { 888 lsp::CompletionItem item; 889 item.label = constraint.str(); 890 item.kind = lsp::CompletionItemKind::Class; 891 item.detail = "pattern metadata"; 892 item.documentation = 893 lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()}; 894 item.insertText = snippetText.str(); 895 item.insertTextFormat = snippetText.empty() 896 ? lsp::InsertTextFormat::PlainText 897 : lsp::InsertTextFormat::Snippet; 898 completionList.items.emplace_back(item); 899 }; 900 901 addSimpleConstraint("benefit", "The `benefit` of matching the pattern.", 902 "benefit($1)"); 903 addSimpleConstraint("recursion", 904 "The pattern properly handles recursive application."); 905 } 906 907 void codeCompleteIncludeFilename(StringRef curPath) final { 908 // Normalize the path to allow for interacting with the file system 909 // utilities. 910 SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(curPath)); 911 llvm::sys::path::native(nativeRelDir); 912 913 // Set of already included completion paths. 914 StringSet<> seenResults; 915 916 // Functor used to add a single include completion item. 917 auto addIncludeCompletion = [&](StringRef path, bool isDirectory) { 918 lsp::CompletionItem item; 919 item.label = (path + (isDirectory ? "/" : "")).str(); 920 item.kind = isDirectory ? lsp::CompletionItemKind::Folder 921 : lsp::CompletionItemKind::File; 922 if (seenResults.insert(item.label).second) 923 completionList.items.emplace_back(item); 924 }; 925 926 // Process the include directories for this file, adding any potential 927 // nested include files or directories. 928 for (StringRef includeDir : includeDirs) { 929 llvm::SmallString<128> dir = includeDir; 930 if (!nativeRelDir.empty()) 931 llvm::sys::path::append(dir, nativeRelDir); 932 933 std::error_code errorCode; 934 for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode), 935 e = llvm::sys::fs::directory_iterator(); 936 !errorCode && it != e; it.increment(errorCode)) { 937 StringRef filename = llvm::sys::path::filename(it->path()); 938 939 // To know whether a symlink should be treated as file or a directory, 940 // we have to stat it. This should be cheap enough as there shouldn't be 941 // many symlinks. 942 llvm::sys::fs::file_type fileType = it->type(); 943 if (fileType == llvm::sys::fs::file_type::symlink_file) { 944 if (auto fileStatus = it->status()) 945 fileType = fileStatus->type(); 946 } 947 948 switch (fileType) { 949 case llvm::sys::fs::file_type::directory_file: 950 addIncludeCompletion(filename, /*isDirectory=*/true); 951 break; 952 case llvm::sys::fs::file_type::regular_file: { 953 // Only consider concrete files that can actually be included by PDLL. 954 if (filename.endswith(".pdll") || filename.endswith(".td")) 955 addIncludeCompletion(filename, /*isDirectory=*/false); 956 break; 957 } 958 default: 959 break; 960 } 961 } 962 } 963 964 // Sort the completion results to make sure the output is deterministic in 965 // the face of different iteration schemes for different platforms. 966 llvm::sort(completionList.items, [](const lsp::CompletionItem &lhs, 967 const lsp::CompletionItem &rhs) { 968 return lhs.label < rhs.label; 969 }); 970 } 971 972 private: 973 lsp::CompletionList &completionList; 974 ods::Context &odsContext; 975 ArrayRef<std::string> includeDirs; 976 }; 977 } // namespace 978 979 lsp::CompletionList 980 PDLDocument::getCodeCompletion(const lsp::URIForFile &uri, 981 const lsp::Position &completePos) { 982 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr); 983 if (!posLoc.isValid()) 984 return lsp::CompletionList(); 985 986 // Adjust the position one further to after the completion trigger token. 987 posLoc = SMLoc::getFromPointer(posLoc.getPointer() + 1); 988 989 // To perform code completion, we run another parse of the module with the 990 // code completion context provided. 991 ods::Context tmpODSContext; 992 lsp::CompletionList completionList; 993 LSPCodeCompleteContext lspCompleteContext( 994 posLoc, completionList, tmpODSContext, sourceMgr.getIncludeDirs()); 995 996 ast::Context tmpContext(tmpODSContext); 997 (void)parsePDLAST(tmpContext, sourceMgr, &lspCompleteContext); 998 999 return completionList; 1000 } 1001 1002 //===----------------------------------------------------------------------===// 1003 // PDLDocument: Signature Help 1004 //===----------------------------------------------------------------------===// 1005 1006 namespace { 1007 class LSPSignatureHelpContext : public CodeCompleteContext { 1008 public: 1009 LSPSignatureHelpContext(SMLoc completeLoc, lsp::SignatureHelp &signatureHelp, 1010 ods::Context &odsContext) 1011 : CodeCompleteContext(completeLoc), signatureHelp(signatureHelp), 1012 odsContext(odsContext) {} 1013 1014 void codeCompleteCallSignature(const ast::CallableDecl *callable, 1015 unsigned currentNumArgs) final { 1016 signatureHelp.activeParameter = currentNumArgs; 1017 1018 lsp::SignatureInformation signatureInfo; 1019 { 1020 llvm::raw_string_ostream strOS(signatureInfo.label); 1021 strOS << callable->getName()->getName() << "("; 1022 auto formatParamFn = [&](const ast::VariableDecl *var) { 1023 unsigned paramStart = strOS.str().size(); 1024 strOS << var->getName().getName() << ": " << var->getType(); 1025 unsigned paramEnd = strOS.str().size(); 1026 signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ 1027 StringRef(strOS.str()).slice(paramStart, paramEnd).str(), 1028 std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()}); 1029 }; 1030 llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn); 1031 strOS << ") -> " << callable->getResultType(); 1032 } 1033 signatureHelp.signatures.emplace_back(std::move(signatureInfo)); 1034 } 1035 1036 void 1037 codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 1038 unsigned currentNumOperands) final { 1039 const ods::Operation *odsOp = 1040 opName ? odsContext.lookupOperation(*opName) : nullptr; 1041 codeCompleteOperationOperandOrResultSignature( 1042 opName, odsOp, odsOp ? odsOp->getOperands() : llvm::None, 1043 currentNumOperands, "operand", "Value"); 1044 } 1045 1046 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 1047 unsigned currentNumResults) final { 1048 const ods::Operation *odsOp = 1049 opName ? odsContext.lookupOperation(*opName) : nullptr; 1050 codeCompleteOperationOperandOrResultSignature( 1051 opName, odsOp, odsOp ? odsOp->getResults() : llvm::None, 1052 currentNumResults, "result", "Type"); 1053 } 1054 1055 void codeCompleteOperationOperandOrResultSignature( 1056 Optional<StringRef> opName, const ods::Operation *odsOp, 1057 ArrayRef<ods::OperandOrResult> values, unsigned currentValue, 1058 StringRef label, StringRef dataType) { 1059 signatureHelp.activeParameter = currentValue; 1060 1061 // If we have ODS information for the operation, add in the ODS signature 1062 // for the operation. We also verify that the current number of values is 1063 // not more than what is defined in ODS, as this will result in an error 1064 // anyways. 1065 if (odsOp && currentValue < values.size()) { 1066 lsp::SignatureInformation signatureInfo; 1067 1068 // Build the signature label. 1069 { 1070 llvm::raw_string_ostream strOS(signatureInfo.label); 1071 strOS << "("; 1072 auto formatFn = [&](const ods::OperandOrResult &value) { 1073 unsigned paramStart = strOS.str().size(); 1074 1075 strOS << value.getName() << ": "; 1076 1077 StringRef constraintDoc = value.getConstraint().getSummary(); 1078 std::string paramDoc; 1079 switch (value.getVariableLengthKind()) { 1080 case ods::VariableLengthKind::Single: 1081 strOS << dataType; 1082 paramDoc = constraintDoc.str(); 1083 break; 1084 case ods::VariableLengthKind::Optional: 1085 strOS << dataType << "?"; 1086 paramDoc = ("optional: " + constraintDoc).str(); 1087 break; 1088 case ods::VariableLengthKind::Variadic: 1089 strOS << dataType << "Range"; 1090 paramDoc = ("variadic: " + constraintDoc).str(); 1091 break; 1092 } 1093 1094 unsigned paramEnd = strOS.str().size(); 1095 signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ 1096 StringRef(strOS.str()).slice(paramStart, paramEnd).str(), 1097 std::make_pair(paramStart, paramEnd), paramDoc}); 1098 }; 1099 llvm::interleaveComma(values, strOS, formatFn); 1100 strOS << ")"; 1101 } 1102 signatureInfo.documentation = 1103 llvm::formatv("`op<{0}>` ODS {1} specification", *opName, label) 1104 .str(); 1105 signatureHelp.signatures.emplace_back(std::move(signatureInfo)); 1106 } 1107 1108 // If there aren't any arguments yet, we also add the generic signature. 1109 if (currentValue == 0 && (!odsOp || !values.empty())) { 1110 lsp::SignatureInformation signatureInfo; 1111 signatureInfo.label = 1112 llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str(); 1113 signatureInfo.documentation = 1114 ("Generic operation " + label + " specification").str(); 1115 signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ 1116 StringRef(signatureInfo.label).drop_front().drop_back().str(), 1117 std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1), 1118 ("All of the " + label + "s of the operation.").str()}); 1119 signatureHelp.signatures.emplace_back(std::move(signatureInfo)); 1120 } 1121 } 1122 1123 private: 1124 lsp::SignatureHelp &signatureHelp; 1125 ods::Context &odsContext; 1126 }; 1127 } // namespace 1128 1129 lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri, 1130 const lsp::Position &helpPos) { 1131 SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr); 1132 if (!posLoc.isValid()) 1133 return lsp::SignatureHelp(); 1134 1135 // Adjust the position one further to after the completion trigger token. 1136 posLoc = SMLoc::getFromPointer(posLoc.getPointer() + 1); 1137 1138 // To perform code completion, we run another parse of the module with the 1139 // code completion context provided. 1140 ods::Context tmpODSContext; 1141 lsp::SignatureHelp signatureHelp; 1142 LSPSignatureHelpContext completeContext(posLoc, signatureHelp, tmpODSContext); 1143 1144 ast::Context tmpContext(tmpODSContext); 1145 (void)parsePDLAST(tmpContext, sourceMgr, &completeContext); 1146 1147 return signatureHelp; 1148 } 1149 1150 //===----------------------------------------------------------------------===// 1151 // PDLTextFileChunk 1152 //===----------------------------------------------------------------------===// 1153 1154 namespace { 1155 /// This class represents a single chunk of an PDL text file. 1156 struct PDLTextFileChunk { 1157 PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri, 1158 StringRef contents, 1159 const std::vector<std::string> &extraDirs, 1160 std::vector<lsp::Diagnostic> &diagnostics) 1161 : lineOffset(lineOffset), 1162 document(uri, contents, extraDirs, diagnostics) {} 1163 1164 /// Adjust the line number of the given range to anchor at the beginning of 1165 /// the file, instead of the beginning of this chunk. 1166 void adjustLocForChunkOffset(lsp::Range &range) { 1167 adjustLocForChunkOffset(range.start); 1168 adjustLocForChunkOffset(range.end); 1169 } 1170 /// Adjust the line number of the given position to anchor at the beginning of 1171 /// the file, instead of the beginning of this chunk. 1172 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } 1173 1174 /// The line offset of this chunk from the beginning of the file. 1175 uint64_t lineOffset; 1176 /// The document referred to by this chunk. 1177 PDLDocument document; 1178 }; 1179 } // namespace 1180 1181 //===----------------------------------------------------------------------===// 1182 // PDLTextFile 1183 //===----------------------------------------------------------------------===// 1184 1185 namespace { 1186 /// This class represents a text file containing one or more PDL documents. 1187 class PDLTextFile { 1188 public: 1189 PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents, 1190 int64_t version, const std::vector<std::string> &extraDirs, 1191 std::vector<lsp::Diagnostic> &diagnostics); 1192 1193 /// Return the current version of this text file. 1194 int64_t getVersion() const { return version; } 1195 1196 //===--------------------------------------------------------------------===// 1197 // LSP Queries 1198 //===--------------------------------------------------------------------===// 1199 1200 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos, 1201 std::vector<lsp::Location> &locations); 1202 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos, 1203 std::vector<lsp::Location> &references); 1204 void getDocumentLinks(const lsp::URIForFile &uri, 1205 std::vector<lsp::DocumentLink> &links); 1206 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri, 1207 lsp::Position hoverPos); 1208 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); 1209 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, 1210 lsp::Position completePos); 1211 lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri, 1212 lsp::Position helpPos); 1213 1214 private: 1215 /// Find the PDL document that contains the given position, and update the 1216 /// position to be anchored at the start of the found chunk instead of the 1217 /// beginning of the file. 1218 PDLTextFileChunk &getChunkFor(lsp::Position &pos); 1219 1220 /// The full string contents of the file. 1221 std::string contents; 1222 1223 /// The version of this file. 1224 int64_t version; 1225 1226 /// The number of lines in the file. 1227 int64_t totalNumLines = 0; 1228 1229 /// The chunks of this file. The order of these chunks is the order in which 1230 /// they appear in the text file. 1231 std::vector<std::unique_ptr<PDLTextFileChunk>> chunks; 1232 }; 1233 } // namespace 1234 1235 PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents, 1236 int64_t version, 1237 const std::vector<std::string> &extraDirs, 1238 std::vector<lsp::Diagnostic> &diagnostics) 1239 : contents(fileContents.str()), version(version) { 1240 // Split the file into separate PDL documents. 1241 // TODO: Find a way to share the split file marker with other tools. We don't 1242 // want to use `splitAndProcessBuffer` here, but we do want to make sure this 1243 // marker doesn't go out of sync. 1244 SmallVector<StringRef, 8> subContents; 1245 StringRef(contents).split(subContents, "// -----"); 1246 chunks.emplace_back(std::make_unique<PDLTextFileChunk>( 1247 /*lineOffset=*/0, uri, subContents.front(), extraDirs, diagnostics)); 1248 1249 uint64_t lineOffset = subContents.front().count('\n'); 1250 for (StringRef docContents : llvm::drop_begin(subContents)) { 1251 unsigned currentNumDiags = diagnostics.size(); 1252 auto chunk = std::make_unique<PDLTextFileChunk>( 1253 lineOffset, uri, docContents, extraDirs, diagnostics); 1254 lineOffset += docContents.count('\n'); 1255 1256 // Adjust locations used in diagnostics to account for the offset from the 1257 // beginning of the file. 1258 for (lsp::Diagnostic &diag : 1259 llvm::drop_begin(diagnostics, currentNumDiags)) { 1260 chunk->adjustLocForChunkOffset(diag.range); 1261 1262 if (!diag.relatedInformation) 1263 continue; 1264 for (auto &it : *diag.relatedInformation) 1265 if (it.location.uri == uri) 1266 chunk->adjustLocForChunkOffset(it.location.range); 1267 } 1268 chunks.emplace_back(std::move(chunk)); 1269 } 1270 totalNumLines = lineOffset; 1271 } 1272 1273 void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri, 1274 lsp::Position defPos, 1275 std::vector<lsp::Location> &locations) { 1276 PDLTextFileChunk &chunk = getChunkFor(defPos); 1277 chunk.document.getLocationsOf(uri, defPos, locations); 1278 1279 // Adjust any locations within this file for the offset of this chunk. 1280 if (chunk.lineOffset == 0) 1281 return; 1282 for (lsp::Location &loc : locations) 1283 if (loc.uri == uri) 1284 chunk.adjustLocForChunkOffset(loc.range); 1285 } 1286 1287 void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri, 1288 lsp::Position pos, 1289 std::vector<lsp::Location> &references) { 1290 PDLTextFileChunk &chunk = getChunkFor(pos); 1291 chunk.document.findReferencesOf(uri, pos, references); 1292 1293 // Adjust any locations within this file for the offset of this chunk. 1294 if (chunk.lineOffset == 0) 1295 return; 1296 for (lsp::Location &loc : references) 1297 if (loc.uri == uri) 1298 chunk.adjustLocForChunkOffset(loc.range); 1299 } 1300 1301 void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri, 1302 std::vector<lsp::DocumentLink> &links) { 1303 chunks.front()->document.getDocumentLinks(uri, links); 1304 for (const auto &it : llvm::drop_begin(chunks)) { 1305 size_t currentNumLinks = links.size(); 1306 it->document.getDocumentLinks(uri, links); 1307 1308 // Adjust any links within this file to account for the offset of this 1309 // chunk. 1310 for (auto &link : llvm::drop_begin(links, currentNumLinks)) 1311 it->adjustLocForChunkOffset(link.range); 1312 } 1313 } 1314 1315 Optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri, 1316 lsp::Position hoverPos) { 1317 PDLTextFileChunk &chunk = getChunkFor(hoverPos); 1318 Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos); 1319 1320 // Adjust any locations within this file for the offset of this chunk. 1321 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) 1322 chunk.adjustLocForChunkOffset(*hoverInfo->range); 1323 return hoverInfo; 1324 } 1325 1326 void PDLTextFile::findDocumentSymbols( 1327 std::vector<lsp::DocumentSymbol> &symbols) { 1328 if (chunks.size() == 1) 1329 return chunks.front()->document.findDocumentSymbols(symbols); 1330 1331 // If there are multiple chunks in this file, we create top-level symbols for 1332 // each chunk. 1333 for (unsigned i = 0, e = chunks.size(); i < e; ++i) { 1334 PDLTextFileChunk &chunk = *chunks[i]; 1335 lsp::Position startPos(chunk.lineOffset); 1336 lsp::Position endPos((i == e - 1) ? totalNumLines - 1 1337 : chunks[i + 1]->lineOffset); 1338 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">", 1339 lsp::SymbolKind::Namespace, 1340 /*range=*/lsp::Range(startPos, endPos), 1341 /*selectionRange=*/lsp::Range(startPos)); 1342 chunk.document.findDocumentSymbols(symbol.children); 1343 1344 // Fixup the locations of document symbols within this chunk. 1345 if (i != 0) { 1346 SmallVector<lsp::DocumentSymbol *> symbolsToFix; 1347 for (lsp::DocumentSymbol &childSymbol : symbol.children) 1348 symbolsToFix.push_back(&childSymbol); 1349 1350 while (!symbolsToFix.empty()) { 1351 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); 1352 chunk.adjustLocForChunkOffset(symbol->range); 1353 chunk.adjustLocForChunkOffset(symbol->selectionRange); 1354 1355 for (lsp::DocumentSymbol &childSymbol : symbol->children) 1356 symbolsToFix.push_back(&childSymbol); 1357 } 1358 } 1359 1360 // Push the symbol for this chunk. 1361 symbols.emplace_back(std::move(symbol)); 1362 } 1363 } 1364 1365 lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri, 1366 lsp::Position completePos) { 1367 PDLTextFileChunk &chunk = getChunkFor(completePos); 1368 lsp::CompletionList completionList = 1369 chunk.document.getCodeCompletion(uri, completePos); 1370 1371 // Adjust any completion locations. 1372 for (lsp::CompletionItem &item : completionList.items) { 1373 if (item.textEdit) 1374 chunk.adjustLocForChunkOffset(item.textEdit->range); 1375 for (lsp::TextEdit &edit : item.additionalTextEdits) 1376 chunk.adjustLocForChunkOffset(edit.range); 1377 } 1378 return completionList; 1379 } 1380 1381 lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri, 1382 lsp::Position helpPos) { 1383 return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos); 1384 } 1385 1386 PDLTextFileChunk &PDLTextFile::getChunkFor(lsp::Position &pos) { 1387 if (chunks.size() == 1) 1388 return *chunks.front(); 1389 1390 // Search for the first chunk with a greater line offset, the previous chunk 1391 // is the one that contains `pos`. 1392 auto it = llvm::upper_bound( 1393 chunks, pos, [](const lsp::Position &pos, const auto &chunk) { 1394 return static_cast<uint64_t>(pos.line) < chunk->lineOffset; 1395 }); 1396 PDLTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it); 1397 pos.line -= chunk.lineOffset; 1398 return chunk; 1399 } 1400 1401 //===----------------------------------------------------------------------===// 1402 // PDLLServer::Impl 1403 //===----------------------------------------------------------------------===// 1404 1405 struct lsp::PDLLServer::Impl { 1406 explicit Impl(const Options &options) 1407 : options(options), compilationDatabase(options.compilationDatabases) {} 1408 1409 /// PDLL LSP options. 1410 const Options &options; 1411 1412 /// The compilation database containing additional information for files 1413 /// passed to the server. 1414 lsp::CompilationDatabase compilationDatabase; 1415 1416 /// The files held by the server, mapped by their URI file name. 1417 llvm::StringMap<std::unique_ptr<PDLTextFile>> files; 1418 }; 1419 1420 //===----------------------------------------------------------------------===// 1421 // PDLLServer 1422 //===----------------------------------------------------------------------===// 1423 1424 lsp::PDLLServer::PDLLServer(const Options &options) 1425 : impl(std::make_unique<Impl>(options)) {} 1426 lsp::PDLLServer::~PDLLServer() = default; 1427 1428 void lsp::PDLLServer::addOrUpdateDocument( 1429 const URIForFile &uri, StringRef contents, int64_t version, 1430 std::vector<Diagnostic> &diagnostics) { 1431 std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs; 1432 if (auto *fileInfo = impl->compilationDatabase.getFileInfo(uri.file())) 1433 llvm::append_range(additionalIncludeDirs, fileInfo->includeDirs); 1434 1435 impl->files[uri.file()] = std::make_unique<PDLTextFile>( 1436 uri, contents, version, additionalIncludeDirs, diagnostics); 1437 } 1438 1439 Optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) { 1440 auto it = impl->files.find(uri.file()); 1441 if (it == impl->files.end()) 1442 return llvm::None; 1443 1444 int64_t version = it->second->getVersion(); 1445 impl->files.erase(it); 1446 return version; 1447 } 1448 1449 void lsp::PDLLServer::getLocationsOf(const URIForFile &uri, 1450 const Position &defPos, 1451 std::vector<Location> &locations) { 1452 auto fileIt = impl->files.find(uri.file()); 1453 if (fileIt != impl->files.end()) 1454 fileIt->second->getLocationsOf(uri, defPos, locations); 1455 } 1456 1457 void lsp::PDLLServer::findReferencesOf(const URIForFile &uri, 1458 const Position &pos, 1459 std::vector<Location> &references) { 1460 auto fileIt = impl->files.find(uri.file()); 1461 if (fileIt != impl->files.end()) 1462 fileIt->second->findReferencesOf(uri, pos, references); 1463 } 1464 1465 void lsp::PDLLServer::getDocumentLinks( 1466 const URIForFile &uri, std::vector<DocumentLink> &documentLinks) { 1467 auto fileIt = impl->files.find(uri.file()); 1468 if (fileIt != impl->files.end()) 1469 return fileIt->second->getDocumentLinks(uri, documentLinks); 1470 } 1471 1472 Optional<lsp::Hover> lsp::PDLLServer::findHover(const URIForFile &uri, 1473 const Position &hoverPos) { 1474 auto fileIt = impl->files.find(uri.file()); 1475 if (fileIt != impl->files.end()) 1476 return fileIt->second->findHover(uri, hoverPos); 1477 return llvm::None; 1478 } 1479 1480 void lsp::PDLLServer::findDocumentSymbols( 1481 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) { 1482 auto fileIt = impl->files.find(uri.file()); 1483 if (fileIt != impl->files.end()) 1484 fileIt->second->findDocumentSymbols(symbols); 1485 } 1486 1487 lsp::CompletionList 1488 lsp::PDLLServer::getCodeCompletion(const URIForFile &uri, 1489 const Position &completePos) { 1490 auto fileIt = impl->files.find(uri.file()); 1491 if (fileIt != impl->files.end()) 1492 return fileIt->second->getCodeCompletion(uri, completePos); 1493 return CompletionList(); 1494 } 1495 1496 lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri, 1497 const Position &helpPos) { 1498 auto fileIt = impl->files.find(uri.file()); 1499 if (fileIt != impl->files.end()) 1500 return fileIt->second->getSignatureHelp(uri, helpPos); 1501 return SignatureHelp(); 1502 } 1503