1 //===--- Quality.cpp ---------------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "Quality.h"
10 #include "AST.h"
11 #include "ASTSignals.h"
12 #include "CompletionModel.h"
13 #include "FileDistance.h"
14 #include "SourceCode.h"
15 #include "index/Symbol.h"
16 #include "clang/AST/ASTContext.h"
17 #include "clang/AST/Decl.h"
18 #include "clang/AST/DeclCXX.h"
19 #include "clang/AST/DeclTemplate.h"
20 #include "clang/AST/DeclVisitor.h"
21 #include "clang/Basic/SourceManager.h"
22 #include "clang/Sema/CodeCompleteConsumer.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/MathExtras.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <algorithm>
29 #include <cmath>
30
31 namespace clang {
32 namespace clangd {
33
hasDeclInMainFile(const Decl & D)34 static bool hasDeclInMainFile(const Decl &D) {
35 auto &SourceMgr = D.getASTContext().getSourceManager();
36 for (auto *Redecl : D.redecls()) {
37 if (isInsideMainFile(Redecl->getLocation(), SourceMgr))
38 return true;
39 }
40 return false;
41 }
42
hasUsingDeclInMainFile(const CodeCompletionResult & R)43 static bool hasUsingDeclInMainFile(const CodeCompletionResult &R) {
44 const auto &Context = R.Declaration->getASTContext();
45 const auto &SourceMgr = Context.getSourceManager();
46 if (R.ShadowDecl) {
47 if (isInsideMainFile(R.ShadowDecl->getLocation(), SourceMgr))
48 return true;
49 }
50 return false;
51 }
52
categorize(const NamedDecl & ND)53 static SymbolQualitySignals::SymbolCategory categorize(const NamedDecl &ND) {
54 if (const auto *FD = dyn_cast<FunctionDecl>(&ND)) {
55 if (FD->isOverloadedOperator())
56 return SymbolQualitySignals::Operator;
57 }
58 class Switch
59 : public ConstDeclVisitor<Switch, SymbolQualitySignals::SymbolCategory> {
60 public:
61 #define MAP(DeclType, Category) \
62 SymbolQualitySignals::SymbolCategory Visit##DeclType(const DeclType *) { \
63 return SymbolQualitySignals::Category; \
64 }
65 MAP(NamespaceDecl, Namespace);
66 MAP(NamespaceAliasDecl, Namespace);
67 MAP(TypeDecl, Type);
68 MAP(TypeAliasTemplateDecl, Type);
69 MAP(ClassTemplateDecl, Type);
70 MAP(CXXConstructorDecl, Constructor);
71 MAP(CXXDestructorDecl, Destructor);
72 MAP(ValueDecl, Variable);
73 MAP(VarTemplateDecl, Variable);
74 MAP(FunctionDecl, Function);
75 MAP(FunctionTemplateDecl, Function);
76 MAP(Decl, Unknown);
77 #undef MAP
78 };
79 return Switch().Visit(&ND);
80 }
81
82 static SymbolQualitySignals::SymbolCategory
categorize(const CodeCompletionResult & R)83 categorize(const CodeCompletionResult &R) {
84 if (R.Declaration)
85 return categorize(*R.Declaration);
86 if (R.Kind == CodeCompletionResult::RK_Macro)
87 return SymbolQualitySignals::Macro;
88 // Everything else is a keyword or a pattern. Patterns are mostly keywords
89 // too, except a few which we recognize by cursor kind.
90 switch (R.CursorKind) {
91 case CXCursor_CXXMethod:
92 return SymbolQualitySignals::Function;
93 case CXCursor_ModuleImportDecl:
94 return SymbolQualitySignals::Namespace;
95 case CXCursor_MacroDefinition:
96 return SymbolQualitySignals::Macro;
97 case CXCursor_TypeRef:
98 return SymbolQualitySignals::Type;
99 case CXCursor_MemberRef:
100 return SymbolQualitySignals::Variable;
101 case CXCursor_Constructor:
102 return SymbolQualitySignals::Constructor;
103 default:
104 return SymbolQualitySignals::Keyword;
105 }
106 }
107
108 static SymbolQualitySignals::SymbolCategory
categorize(const index::SymbolInfo & D)109 categorize(const index::SymbolInfo &D) {
110 switch (D.Kind) {
111 case index::SymbolKind::Namespace:
112 case index::SymbolKind::NamespaceAlias:
113 return SymbolQualitySignals::Namespace;
114 case index::SymbolKind::Macro:
115 return SymbolQualitySignals::Macro;
116 case index::SymbolKind::Enum:
117 case index::SymbolKind::Struct:
118 case index::SymbolKind::Class:
119 case index::SymbolKind::Protocol:
120 case index::SymbolKind::Extension:
121 case index::SymbolKind::Union:
122 case index::SymbolKind::TypeAlias:
123 case index::SymbolKind::TemplateTypeParm:
124 case index::SymbolKind::TemplateTemplateParm:
125 case index::SymbolKind::Concept:
126 return SymbolQualitySignals::Type;
127 case index::SymbolKind::Function:
128 case index::SymbolKind::ClassMethod:
129 case index::SymbolKind::InstanceMethod:
130 case index::SymbolKind::StaticMethod:
131 case index::SymbolKind::InstanceProperty:
132 case index::SymbolKind::ClassProperty:
133 case index::SymbolKind::StaticProperty:
134 case index::SymbolKind::ConversionFunction:
135 return SymbolQualitySignals::Function;
136 case index::SymbolKind::Destructor:
137 return SymbolQualitySignals::Destructor;
138 case index::SymbolKind::Constructor:
139 return SymbolQualitySignals::Constructor;
140 case index::SymbolKind::Variable:
141 case index::SymbolKind::Field:
142 case index::SymbolKind::EnumConstant:
143 case index::SymbolKind::Parameter:
144 case index::SymbolKind::NonTypeTemplateParm:
145 return SymbolQualitySignals::Variable;
146 case index::SymbolKind::Using:
147 case index::SymbolKind::Module:
148 case index::SymbolKind::Unknown:
149 return SymbolQualitySignals::Unknown;
150 }
151 llvm_unreachable("Unknown index::SymbolKind");
152 }
153
isInstanceMember(const NamedDecl * ND)154 static bool isInstanceMember(const NamedDecl *ND) {
155 if (!ND)
156 return false;
157 if (const auto *TP = dyn_cast<FunctionTemplateDecl>(ND))
158 ND = TP->TemplateDecl::getTemplatedDecl();
159 if (const auto *CM = dyn_cast<CXXMethodDecl>(ND))
160 return !CM->isStatic();
161 return isa<FieldDecl>(ND); // Note that static fields are VarDecl.
162 }
163
isInstanceMember(const index::SymbolInfo & D)164 static bool isInstanceMember(const index::SymbolInfo &D) {
165 switch (D.Kind) {
166 case index::SymbolKind::InstanceMethod:
167 case index::SymbolKind::InstanceProperty:
168 case index::SymbolKind::Field:
169 return true;
170 default:
171 return false;
172 }
173 }
174
merge(const CodeCompletionResult & SemaCCResult)175 void SymbolQualitySignals::merge(const CodeCompletionResult &SemaCCResult) {
176 Deprecated |= (SemaCCResult.Availability == CXAvailability_Deprecated);
177 Category = categorize(SemaCCResult);
178
179 if (SemaCCResult.Declaration) {
180 ImplementationDetail |= isImplementationDetail(SemaCCResult.Declaration);
181 if (auto *ID = SemaCCResult.Declaration->getIdentifier())
182 ReservedName = ReservedName || isReservedName(ID->getName());
183 } else if (SemaCCResult.Kind == CodeCompletionResult::RK_Macro)
184 ReservedName =
185 ReservedName || isReservedName(SemaCCResult.Macro->getName());
186 }
187
merge(const Symbol & IndexResult)188 void SymbolQualitySignals::merge(const Symbol &IndexResult) {
189 Deprecated |= (IndexResult.Flags & Symbol::Deprecated);
190 ImplementationDetail |= (IndexResult.Flags & Symbol::ImplementationDetail);
191 References = std::max(IndexResult.References, References);
192 Category = categorize(IndexResult.SymInfo);
193 ReservedName = ReservedName || isReservedName(IndexResult.Name);
194 }
195
evaluateHeuristics() const196 float SymbolQualitySignals::evaluateHeuristics() const {
197 float Score = 1;
198
199 // This avoids a sharp gradient for tail symbols, and also neatly avoids the
200 // question of whether 0 references means a bad symbol or missing data.
201 if (References >= 10) {
202 // Use a sigmoid style boosting function, which flats out nicely for large
203 // numbers (e.g. 2.58 for 1M references).
204 // The following boosting function is equivalent to:
205 // m = 0.06
206 // f = 12.0
207 // boost = f * sigmoid(m * std::log(References)) - 0.5 * f + 0.59
208 // Sample data points: (10, 1.00), (100, 1.41), (1000, 1.82),
209 // (10K, 2.21), (100K, 2.58), (1M, 2.94)
210 float S = std::pow(References, -0.06);
211 Score *= 6.0 * (1 - S) / (1 + S) + 0.59;
212 }
213
214 if (Deprecated)
215 Score *= 0.1f;
216 if (ReservedName)
217 Score *= 0.1f;
218 if (ImplementationDetail)
219 Score *= 0.2f;
220
221 switch (Category) {
222 case Keyword: // Often relevant, but misses most signals.
223 Score *= 4; // FIXME: important keywords should have specific boosts.
224 break;
225 case Type:
226 case Function:
227 case Variable:
228 Score *= 1.1f;
229 break;
230 case Namespace:
231 Score *= 0.8f;
232 break;
233 case Macro:
234 case Destructor:
235 case Operator:
236 Score *= 0.5f;
237 break;
238 case Constructor: // No boost constructors so they are after class types.
239 case Unknown:
240 break;
241 }
242
243 return Score;
244 }
245
operator <<(llvm::raw_ostream & OS,const SymbolQualitySignals & S)246 llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
247 const SymbolQualitySignals &S) {
248 OS << llvm::formatv("=== Symbol quality: {0}\n", S.evaluateHeuristics());
249 OS << llvm::formatv("\tReferences: {0}\n", S.References);
250 OS << llvm::formatv("\tDeprecated: {0}\n", S.Deprecated);
251 OS << llvm::formatv("\tReserved name: {0}\n", S.ReservedName);
252 OS << llvm::formatv("\tImplementation detail: {0}\n", S.ImplementationDetail);
253 OS << llvm::formatv("\tCategory: {0}\n", static_cast<int>(S.Category));
254 return OS;
255 }
256
257 static SymbolRelevanceSignals::AccessibleScope
computeScope(const NamedDecl * D)258 computeScope(const NamedDecl *D) {
259 // Injected "Foo" within the class "Foo" has file scope, not class scope.
260 const DeclContext *DC = D->getDeclContext();
261 if (auto *R = dyn_cast_or_null<RecordDecl>(D))
262 if (R->isInjectedClassName())
263 DC = DC->getParent();
264 // Class constructor should have the same scope as the class.
265 if (isa<CXXConstructorDecl>(D))
266 DC = DC->getParent();
267 bool InClass = false;
268 for (; !DC->isFileContext(); DC = DC->getParent()) {
269 if (DC->isFunctionOrMethod())
270 return SymbolRelevanceSignals::FunctionScope;
271 InClass = InClass || DC->isRecord();
272 }
273 if (InClass)
274 return SymbolRelevanceSignals::ClassScope;
275 // ExternalLinkage threshold could be tweaked, e.g. module-visible as global.
276 // Avoid caching linkage if it may change after enclosing code completion.
277 if (hasUnstableLinkage(D) || D->getLinkageInternal() < ExternalLinkage)
278 return SymbolRelevanceSignals::FileScope;
279 return SymbolRelevanceSignals::GlobalScope;
280 }
281
merge(const Symbol & IndexResult)282 void SymbolRelevanceSignals::merge(const Symbol &IndexResult) {
283 SymbolURI = IndexResult.CanonicalDeclaration.FileURI;
284 SymbolScope = IndexResult.Scope;
285 IsInstanceMember |= isInstanceMember(IndexResult.SymInfo);
286 if (!(IndexResult.Flags & Symbol::VisibleOutsideFile)) {
287 Scope = AccessibleScope::FileScope;
288 }
289 if (MainFileSignals) {
290 MainFileRefs =
291 std::max(MainFileRefs,
292 MainFileSignals->ReferencedSymbols.lookup(IndexResult.ID));
293 ScopeRefsInFile =
294 std::max(ScopeRefsInFile,
295 MainFileSignals->RelatedNamespaces.lookup(IndexResult.Scope));
296 }
297 }
298
computeASTSignals(const CodeCompletionResult & SemaResult)299 void SymbolRelevanceSignals::computeASTSignals(
300 const CodeCompletionResult &SemaResult) {
301 if (!MainFileSignals)
302 return;
303 if ((SemaResult.Kind != CodeCompletionResult::RK_Declaration) &&
304 (SemaResult.Kind != CodeCompletionResult::RK_Pattern))
305 return;
306 if (const NamedDecl *ND = SemaResult.getDeclaration()) {
307 auto ID = getSymbolID(ND);
308 if (!ID)
309 return;
310 MainFileRefs =
311 std::max(MainFileRefs, MainFileSignals->ReferencedSymbols.lookup(ID));
312 if (const auto *NSD = dyn_cast<NamespaceDecl>(ND->getDeclContext())) {
313 if (NSD->isAnonymousNamespace())
314 return;
315 std::string Scope = printNamespaceScope(*NSD);
316 if (!Scope.empty())
317 ScopeRefsInFile = std::max(
318 ScopeRefsInFile, MainFileSignals->RelatedNamespaces.lookup(Scope));
319 }
320 }
321 }
322
merge(const CodeCompletionResult & SemaCCResult)323 void SymbolRelevanceSignals::merge(const CodeCompletionResult &SemaCCResult) {
324 if (SemaCCResult.Availability == CXAvailability_NotAvailable ||
325 SemaCCResult.Availability == CXAvailability_NotAccessible)
326 Forbidden = true;
327
328 if (SemaCCResult.Declaration) {
329 SemaSaysInScope = true;
330 // We boost things that have decls in the main file. We give a fixed score
331 // for all other declarations in sema as they are already included in the
332 // translation unit.
333 float DeclProximity = (hasDeclInMainFile(*SemaCCResult.Declaration) ||
334 hasUsingDeclInMainFile(SemaCCResult))
335 ? 1.0
336 : 0.6;
337 SemaFileProximityScore = std::max(DeclProximity, SemaFileProximityScore);
338 IsInstanceMember |= isInstanceMember(SemaCCResult.Declaration);
339 InBaseClass |= SemaCCResult.InBaseClass;
340 }
341
342 computeASTSignals(SemaCCResult);
343 // Declarations are scoped, others (like macros) are assumed global.
344 if (SemaCCResult.Declaration)
345 Scope = std::min(Scope, computeScope(SemaCCResult.Declaration));
346
347 NeedsFixIts = !SemaCCResult.FixIts.empty();
348 }
349
fileProximityScore(unsigned FileDistance)350 static float fileProximityScore(unsigned FileDistance) {
351 // Range: [0, 1]
352 // FileDistance = [0, 1, 2, 3, 4, .., FileDistance::Unreachable]
353 // Score = [1, 0.82, 0.67, 0.55, 0.45, .., 0]
354 if (FileDistance == FileDistance::Unreachable)
355 return 0;
356 // Assume approximately default options are used for sensible scoring.
357 return std::exp(FileDistance * -0.4f / FileDistanceOptions().UpCost);
358 }
359
scopeProximityScore(unsigned ScopeDistance)360 static float scopeProximityScore(unsigned ScopeDistance) {
361 // Range: [0.6, 2].
362 // ScopeDistance = [0, 1, 2, 3, 4, 5, 6, 7, .., FileDistance::Unreachable]
363 // Score = [2.0, 1.55, 1.2, 0.93, 0.72, 0.65, 0.65, 0.65, .., 0.6]
364 if (ScopeDistance == FileDistance::Unreachable)
365 return 0.6f;
366 return std::max(0.65, 2.0 * std::pow(0.6, ScopeDistance / 2.0));
367 }
368
369 static llvm::Optional<llvm::StringRef>
wordMatching(llvm::StringRef Name,const llvm::StringSet<> * ContextWords)370 wordMatching(llvm::StringRef Name, const llvm::StringSet<> *ContextWords) {
371 if (ContextWords)
372 for (const auto &Word : ContextWords->keys())
373 if (Name.contains_insensitive(Word))
374 return Word;
375 return llvm::None;
376 }
377
378 SymbolRelevanceSignals::DerivedSignals
calculateDerivedSignals() const379 SymbolRelevanceSignals::calculateDerivedSignals() const {
380 DerivedSignals Derived;
381 Derived.NameMatchesContext = wordMatching(Name, ContextWords).has_value();
382 Derived.FileProximityDistance = !FileProximityMatch || SymbolURI.empty()
383 ? FileDistance::Unreachable
384 : FileProximityMatch->distance(SymbolURI);
385 if (ScopeProximityMatch) {
386 // For global symbol, the distance is 0.
387 Derived.ScopeProximityDistance =
388 SymbolScope ? ScopeProximityMatch->distance(*SymbolScope) : 0;
389 }
390 return Derived;
391 }
392
evaluateHeuristics() const393 float SymbolRelevanceSignals::evaluateHeuristics() const {
394 DerivedSignals Derived = calculateDerivedSignals();
395 float Score = 1;
396
397 if (Forbidden)
398 return 0;
399
400 Score *= NameMatch;
401
402 // File proximity scores are [0,1] and we translate them into a multiplier in
403 // the range from 1 to 3.
404 Score *= 1 + 2 * std::max(fileProximityScore(Derived.FileProximityDistance),
405 SemaFileProximityScore);
406
407 if (ScopeProximityMatch)
408 // Use a constant scope boost for sema results, as scopes of sema results
409 // can be tricky (e.g. class/function scope). Set to the max boost as we
410 // don't load top-level symbols from the preamble and sema results are
411 // always in the accessible scope.
412 Score *= SemaSaysInScope
413 ? 2.0
414 : scopeProximityScore(Derived.ScopeProximityDistance);
415
416 if (Derived.NameMatchesContext)
417 Score *= 1.5;
418
419 // Symbols like local variables may only be referenced within their scope.
420 // Conversely if we're in that scope, it's likely we'll reference them.
421 if (Query == CodeComplete) {
422 // The narrower the scope where a symbol is visible, the more likely it is
423 // to be relevant when it is available.
424 switch (Scope) {
425 case GlobalScope:
426 break;
427 case FileScope:
428 Score *= 1.5f;
429 break;
430 case ClassScope:
431 Score *= 2;
432 break;
433 case FunctionScope:
434 Score *= 4;
435 break;
436 }
437 } else {
438 // For non-completion queries, the wider the scope where a symbol is
439 // visible, the more likely it is to be relevant.
440 switch (Scope) {
441 case GlobalScope:
442 break;
443 case FileScope:
444 Score *= 0.5f;
445 break;
446 default:
447 // TODO: Handle other scopes as we start to use them for index results.
448 break;
449 }
450 }
451
452 if (TypeMatchesPreferred)
453 Score *= 5.0;
454
455 // Penalize non-instance members when they are accessed via a class instance.
456 if (!IsInstanceMember &&
457 (Context == CodeCompletionContext::CCC_DotMemberAccess ||
458 Context == CodeCompletionContext::CCC_ArrowMemberAccess)) {
459 Score *= 0.2f;
460 }
461
462 if (InBaseClass)
463 Score *= 0.5f;
464
465 // Penalize for FixIts.
466 if (NeedsFixIts)
467 Score *= 0.5f;
468
469 // Use a sigmoid style boosting function similar to `References`, which flats
470 // out nicely for large values. This avoids a sharp gradient for heavily
471 // referenced symbols. Use smaller gradient for ScopeRefsInFile since ideally
472 // MainFileRefs <= ScopeRefsInFile.
473 if (MainFileRefs >= 2) {
474 // E.g.: (2, 1.12), (9, 2.0), (48, 3.0).
475 float S = std::pow(MainFileRefs, -0.11);
476 Score *= 11.0 * (1 - S) / (1 + S) + 0.7;
477 }
478 if (ScopeRefsInFile >= 2) {
479 // E.g.: (2, 1.04), (14, 2.0), (109, 3.0), (400, 3.6).
480 float S = std::pow(ScopeRefsInFile, -0.10);
481 Score *= 10.0 * (1 - S) / (1 + S) + 0.7;
482 }
483
484 return Score;
485 }
486
operator <<(llvm::raw_ostream & OS,const SymbolRelevanceSignals & S)487 llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
488 const SymbolRelevanceSignals &S) {
489 OS << llvm::formatv("=== Symbol relevance: {0}\n", S.evaluateHeuristics());
490 OS << llvm::formatv("\tName: {0}\n", S.Name);
491 OS << llvm::formatv("\tName match: {0}\n", S.NameMatch);
492 if (S.ContextWords)
493 OS << llvm::formatv(
494 "\tMatching context word: {0}\n",
495 wordMatching(S.Name, S.ContextWords).value_or("<none>"));
496 OS << llvm::formatv("\tForbidden: {0}\n", S.Forbidden);
497 OS << llvm::formatv("\tNeedsFixIts: {0}\n", S.NeedsFixIts);
498 OS << llvm::formatv("\tIsInstanceMember: {0}\n", S.IsInstanceMember);
499 OS << llvm::formatv("\tInBaseClass: {0}\n", S.InBaseClass);
500 OS << llvm::formatv("\tContext: {0}\n", getCompletionKindString(S.Context));
501 OS << llvm::formatv("\tQuery type: {0}\n", static_cast<int>(S.Query));
502 OS << llvm::formatv("\tScope: {0}\n", static_cast<int>(S.Scope));
503
504 OS << llvm::formatv("\tSymbol URI: {0}\n", S.SymbolURI);
505 OS << llvm::formatv("\tSymbol scope: {0}\n",
506 S.SymbolScope ? *S.SymbolScope : "<None>");
507
508 SymbolRelevanceSignals::DerivedSignals Derived = S.calculateDerivedSignals();
509 if (S.FileProximityMatch) {
510 unsigned Score = fileProximityScore(Derived.FileProximityDistance);
511 OS << llvm::formatv("\tIndex URI proximity: {0} (distance={1})\n", Score,
512 Derived.FileProximityDistance);
513 }
514 OS << llvm::formatv("\tSema file proximity: {0}\n", S.SemaFileProximityScore);
515
516 OS << llvm::formatv("\tSema says in scope: {0}\n", S.SemaSaysInScope);
517 if (S.ScopeProximityMatch)
518 OS << llvm::formatv("\tIndex scope boost: {0}\n",
519 scopeProximityScore(Derived.ScopeProximityDistance));
520
521 OS << llvm::formatv(
522 "\tType matched preferred: {0} (Context type: {1}, Symbol type: {2}\n",
523 S.TypeMatchesPreferred, S.HadContextType, S.HadSymbolType);
524
525 return OS;
526 }
527
// Combines a quality score and a relevance score into a single score by
// multiplying them.
float evaluateSymbolAndRelevance(float SymbolQuality, float SymbolRelevance) {
  float Combined = SymbolQuality;
  Combined *= SymbolRelevance;
  return Combined;
}
531
532 DecisionForestScores
evaluateDecisionForest(const SymbolQualitySignals & Quality,const SymbolRelevanceSignals & Relevance,float Base)533 evaluateDecisionForest(const SymbolQualitySignals &Quality,
534 const SymbolRelevanceSignals &Relevance, float Base) {
535 Example E;
536 E.setIsDeprecated(Quality.Deprecated);
537 E.setIsReservedName(Quality.ReservedName);
538 E.setIsImplementationDetail(Quality.ImplementationDetail);
539 E.setNumReferences(Quality.References);
540 E.setSymbolCategory(Quality.Category);
541
542 SymbolRelevanceSignals::DerivedSignals Derived =
543 Relevance.calculateDerivedSignals();
544 int NumMatch = 0;
545 if (Relevance.ContextWords) {
546 for (const auto &Word : Relevance.ContextWords->keys()) {
547 if (Relevance.Name.contains_insensitive(Word)) {
548 ++NumMatch;
549 }
550 }
551 }
552 E.setIsNameInContext(NumMatch > 0);
553 E.setNumNameInContext(NumMatch);
554 E.setFractionNameInContext(
555 Relevance.ContextWords && !Relevance.ContextWords->empty()
556 ? NumMatch * 1.0 / Relevance.ContextWords->size()
557 : 0);
558 E.setIsInBaseClass(Relevance.InBaseClass);
559 E.setFileProximityDistanceCost(Derived.FileProximityDistance);
560 E.setSemaFileProximityScore(Relevance.SemaFileProximityScore);
561 E.setSymbolScopeDistanceCost(Derived.ScopeProximityDistance);
562 E.setSemaSaysInScope(Relevance.SemaSaysInScope);
563 E.setScope(Relevance.Scope);
564 E.setContextKind(Relevance.Context);
565 E.setIsInstanceMember(Relevance.IsInstanceMember);
566 E.setHadContextType(Relevance.HadContextType);
567 E.setHadSymbolType(Relevance.HadSymbolType);
568 E.setTypeMatchesPreferred(Relevance.TypeMatchesPreferred);
569
570 DecisionForestScores Scores;
571 // Exponentiating DecisionForest prediction makes the score of each tree a
572 // multiplciative boost (like NameMatch). This allows us to weigh the
573 // prediciton score and NameMatch appropriately.
574 Scores.ExcludingName = pow(Base, Evaluate(E));
575 // Following cases are not part of the generated training dataset:
576 // - Symbols with `NeedsFixIts`.
577 // - Forbidden symbols.
578 // - Keywords: Dataset contains only macros and decls.
579 if (Relevance.NeedsFixIts)
580 Scores.ExcludingName *= 0.5;
581 if (Relevance.Forbidden)
582 Scores.ExcludingName *= 0;
583 if (Quality.Category == SymbolQualitySignals::Keyword)
584 Scores.ExcludingName *= 4;
585
586 // NameMatch should be a multiplier on total score to support rescoring.
587 Scores.Total = Relevance.NameMatch * Scores.ExcludingName;
588 return Scores;
589 }
590
591 // Produces an integer that sorts in the same order as F.
592 // That is: a < b <==> encodeFloat(a) < encodeFloat(b).
encodeFloat(float F)593 static uint32_t encodeFloat(float F) {
594 static_assert(std::numeric_limits<float>::is_iec559, "");
595 constexpr uint32_t TopBit = ~(~uint32_t{0} >> 1);
596
597 // Get the bits of the float. Endianness is the same as for integers.
598 uint32_t U = llvm::FloatToBits(F);
599 // IEEE 754 floats compare like sign-magnitude integers.
600 if (U & TopBit) // Negative float.
601 return 0 - U; // Map onto the low half of integers, order reversed.
602 return U + TopBit; // Positive floats map onto the high half of integers.
603 }
604
sortText(float Score,llvm::StringRef Name)605 std::string sortText(float Score, llvm::StringRef Name) {
606 // We convert -Score to an integer, and hex-encode for readability.
607 // Example: [0.5, "foo"] -> "41000000foo"
608 std::string S;
609 llvm::raw_string_ostream OS(S);
610 llvm::write_hex(OS, encodeFloat(-Score), llvm::HexPrintStyle::Lower,
611 /*Width=*/2 * sizeof(Score));
612 OS << Name;
613 OS.flush();
614 return S;
615 }
616
operator <<(llvm::raw_ostream & OS,const SignatureQualitySignals & S)617 llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
618 const SignatureQualitySignals &S) {
619 OS << llvm::formatv("=== Signature Quality:\n");
620 OS << llvm::formatv("\tNumber of parameters: {0}\n", S.NumberOfParameters);
621 OS << llvm::formatv("\tNumber of optional parameters: {0}\n",
622 S.NumberOfOptionalParameters);
623 OS << llvm::formatv("\tKind: {0}\n", S.Kind);
624 return OS;
625 }
626
627 } // namespace clangd
628 } // namespace clang
629