1 //===- ExtractAPI/ExtractAPIConsumer.cpp ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the ExtractAPIAction, and ASTVisitor/Consumer to
11 /// collect API information.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "TypedefUnderlyingTypeResolver.h"
16 #include "clang/AST/ASTConsumer.h"
17 #include "clang/AST/ASTContext.h"
18 #include "clang/AST/Decl.h"
19 #include "clang/AST/DeclCXX.h"
20 #include "clang/AST/ParentMapContext.h"
21 #include "clang/AST/RawCommentList.h"
22 #include "clang/AST/RecursiveASTVisitor.h"
23 #include "clang/Basic/TargetInfo.h"
24 #include "clang/ExtractAPI/API.h"
25 #include "clang/ExtractAPI/AvailabilityInfo.h"
26 #include "clang/ExtractAPI/DeclarationFragments.h"
27 #include "clang/ExtractAPI/FrontendActions.h"
28 #include "clang/ExtractAPI/Serialization/SymbolGraphSerializer.h"
29 #include "clang/Frontend/ASTConsumers.h"
30 #include "clang/Frontend/CompilerInstance.h"
31 #include "clang/Frontend/FrontendOptions.h"
32 #include "clang/Lex/MacroInfo.h"
33 #include "clang/Lex/PPCallbacks.h"
34 #include "clang/Lex/PreprocessorOptions.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/SmallVector.h"
37 #include "llvm/Support/MemoryBuffer.h"
38 #include "llvm/Support/raw_ostream.h"
39 
40 using namespace clang;
41 using namespace extractapi;
42 
43 namespace {
44 
45 StringRef getTypedefName(const TagDecl *Decl) {
46   if (const auto *TypedefDecl = Decl->getTypedefNameForAnonDecl())
47     return TypedefDecl->getName();
48 
49   return {};
50 }
51 
52 /// The RecursiveASTVisitor to traverse symbol declarations and collect API
53 /// information.
54 class ExtractAPIVisitor : public RecursiveASTVisitor<ExtractAPIVisitor> {
55 public:
56   ExtractAPIVisitor(ASTContext &Context, APISet &API)
57       : Context(Context), API(API) {}
58 
59   const APISet &getAPI() const { return API; }
60 
61   bool VisitVarDecl(const VarDecl *Decl) {
62     // Skip function parameters.
63     if (isa<ParmVarDecl>(Decl))
64       return true;
65 
66     // Skip non-global variables in records (struct/union/class).
67     if (Decl->getDeclContext()->isRecord())
68       return true;
69 
70     // Skip local variables inside function or method.
71     if (!Decl->isDefinedOutsideFunctionOrMethod())
72       return true;
73 
74     // If this is a template but not specialization or instantiation, skip.
75     if (Decl->getASTContext().getTemplateOrSpecializationInfo(Decl) &&
76         Decl->getTemplateSpecializationKind() == TSK_Undeclared)
77       return true;
78 
79     // Collect symbol information.
80     StringRef Name = Decl->getName();
81     StringRef USR = API.recordUSR(Decl);
82     PresumedLoc Loc =
83         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
84     AvailabilityInfo Availability = getAvailability(Decl);
85     LinkageInfo Linkage = Decl->getLinkageAndVisibility();
86     DocComment Comment;
87     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
88       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
89                                               Context.getDiagnostics());
90 
91     // Build declaration fragments and sub-heading for the variable.
92     DeclarationFragments Declaration =
93         DeclarationFragmentsBuilder::getFragmentsForVar(Decl);
94     DeclarationFragments SubHeading =
95         DeclarationFragmentsBuilder::getSubHeading(Decl);
96 
97     // Add the global variable record to the API set.
98     API.addGlobalVar(Name, USR, Loc, Availability, Linkage, Comment,
99                      Declaration, SubHeading);
100     return true;
101   }
102 
103   bool VisitFunctionDecl(const FunctionDecl *Decl) {
104     if (const auto *Method = dyn_cast<CXXMethodDecl>(Decl)) {
105       // Skip member function in class templates.
106       if (Method->getParent()->getDescribedClassTemplate() != nullptr)
107         return true;
108 
109       // Skip methods in records.
110       for (auto P : Context.getParents(*Method)) {
111         if (P.get<CXXRecordDecl>())
112           return true;
113       }
114 
115       // Skip ConstructorDecl and DestructorDecl.
116       if (isa<CXXConstructorDecl>(Method) || isa<CXXDestructorDecl>(Method))
117         return true;
118     }
119 
120     // Skip templated functions.
121     switch (Decl->getTemplatedKind()) {
122     case FunctionDecl::TK_NonTemplate:
123       break;
124     case FunctionDecl::TK_MemberSpecialization:
125     case FunctionDecl::TK_FunctionTemplateSpecialization:
126       if (auto *TemplateInfo = Decl->getTemplateSpecializationInfo()) {
127         if (!TemplateInfo->isExplicitInstantiationOrSpecialization())
128           return true;
129       }
130       break;
131     case FunctionDecl::TK_FunctionTemplate:
132     case FunctionDecl::TK_DependentFunctionTemplateSpecialization:
133       return true;
134     }
135 
136     // Collect symbol information.
137     StringRef Name = Decl->getName();
138     StringRef USR = API.recordUSR(Decl);
139     PresumedLoc Loc =
140         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
141     AvailabilityInfo Availability = getAvailability(Decl);
142     LinkageInfo Linkage = Decl->getLinkageAndVisibility();
143     DocComment Comment;
144     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
145       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
146                                               Context.getDiagnostics());
147 
148     // Build declaration fragments, sub-heading, and signature of the function.
149     DeclarationFragments Declaration =
150         DeclarationFragmentsBuilder::getFragmentsForFunction(Decl);
151     DeclarationFragments SubHeading =
152         DeclarationFragmentsBuilder::getSubHeading(Decl);
153     FunctionSignature Signature =
154         DeclarationFragmentsBuilder::getFunctionSignature(Decl);
155 
156     // Add the function record to the API set.
157     API.addFunction(Name, USR, Loc, Availability, Linkage, Comment, Declaration,
158                     SubHeading, Signature);
159     return true;
160   }
161 
162   bool VisitEnumDecl(const EnumDecl *Decl) {
163     if (!Decl->isComplete())
164       return true;
165 
166     // Skip forward declaration.
167     if (!Decl->isThisDeclarationADefinition())
168       return true;
169 
170     // Collect symbol information.
171     StringRef Name = Decl->getName();
172     if (Name.empty())
173       Name = getTypedefName(Decl);
174     StringRef USR = API.recordUSR(Decl);
175     PresumedLoc Loc =
176         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
177     AvailabilityInfo Availability = getAvailability(Decl);
178     DocComment Comment;
179     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
180       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
181                                               Context.getDiagnostics());
182 
183     // Build declaration fragments and sub-heading for the enum.
184     DeclarationFragments Declaration =
185         DeclarationFragmentsBuilder::getFragmentsForEnum(Decl);
186     DeclarationFragments SubHeading =
187         DeclarationFragmentsBuilder::getSubHeading(Decl);
188 
189     EnumRecord *EnumRecord = API.addEnum(Name, USR, Loc, Availability, Comment,
190                                          Declaration, SubHeading);
191 
192     // Now collect information about the enumerators in this enum.
193     recordEnumConstants(EnumRecord, Decl->enumerators());
194 
195     return true;
196   }
197 
198   bool VisitRecordDecl(const RecordDecl *Decl) {
199     if (!Decl->isCompleteDefinition())
200       return true;
201 
202     // Skip C++ structs/classes/unions
203     // TODO: support C++ records
204     if (isa<CXXRecordDecl>(Decl))
205       return true;
206 
207     // Collect symbol information.
208     StringRef Name = Decl->getName();
209     if (Name.empty())
210       Name = getTypedefName(Decl);
211     StringRef USR = API.recordUSR(Decl);
212     PresumedLoc Loc =
213         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
214     AvailabilityInfo Availability = getAvailability(Decl);
215     DocComment Comment;
216     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
217       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
218                                               Context.getDiagnostics());
219 
220     // Build declaration fragments and sub-heading for the struct.
221     DeclarationFragments Declaration =
222         DeclarationFragmentsBuilder::getFragmentsForStruct(Decl);
223     DeclarationFragments SubHeading =
224         DeclarationFragmentsBuilder::getSubHeading(Decl);
225 
226     StructRecord *StructRecord = API.addStruct(
227         Name, USR, Loc, Availability, Comment, Declaration, SubHeading);
228 
229     // Now collect information about the fields in this struct.
230     recordStructFields(StructRecord, Decl->fields());
231 
232     return true;
233   }
234 
235   bool VisitObjCInterfaceDecl(const ObjCInterfaceDecl *Decl) {
236     // Skip forward declaration for classes (@class)
237     if (!Decl->isThisDeclarationADefinition())
238       return true;
239 
240     // Collect symbol information.
241     StringRef Name = Decl->getName();
242     StringRef USR = API.recordUSR(Decl);
243     PresumedLoc Loc =
244         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
245     AvailabilityInfo Availability = getAvailability(Decl);
246     LinkageInfo Linkage = Decl->getLinkageAndVisibility();
247     DocComment Comment;
248     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
249       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
250                                               Context.getDiagnostics());
251 
252     // Build declaration fragments and sub-heading for the interface.
253     DeclarationFragments Declaration =
254         DeclarationFragmentsBuilder::getFragmentsForObjCInterface(Decl);
255     DeclarationFragments SubHeading =
256         DeclarationFragmentsBuilder::getSubHeading(Decl);
257 
258     // Collect super class information.
259     SymbolReference SuperClass;
260     if (const auto *SuperClassDecl = Decl->getSuperClass()) {
261       SuperClass.Name = SuperClassDecl->getObjCRuntimeNameAsString();
262       SuperClass.USR = API.recordUSR(SuperClassDecl);
263     }
264 
265     ObjCInterfaceRecord *ObjCInterfaceRecord =
266         API.addObjCInterface(Name, USR, Loc, Availability, Linkage, Comment,
267                              Declaration, SubHeading, SuperClass);
268 
269     // Record all methods (selectors). This doesn't include automatically
270     // synthesized property methods.
271     recordObjCMethods(ObjCInterfaceRecord, Decl->methods());
272     recordObjCProperties(ObjCInterfaceRecord, Decl->properties());
273     recordObjCInstanceVariables(ObjCInterfaceRecord, Decl->ivars());
274     recordObjCProtocols(ObjCInterfaceRecord, Decl->protocols());
275 
276     return true;
277   }
278 
279   bool VisitObjCProtocolDecl(const ObjCProtocolDecl *Decl) {
280     // Skip forward declaration for protocols (@protocol).
281     if (!Decl->isThisDeclarationADefinition())
282       return true;
283 
284     // Collect symbol information.
285     StringRef Name = Decl->getName();
286     StringRef USR = API.recordUSR(Decl);
287     PresumedLoc Loc =
288         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
289     AvailabilityInfo Availability = getAvailability(Decl);
290     DocComment Comment;
291     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
292       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
293                                               Context.getDiagnostics());
294 
295     // Build declaration fragments and sub-heading for the protocol.
296     DeclarationFragments Declaration =
297         DeclarationFragmentsBuilder::getFragmentsForObjCProtocol(Decl);
298     DeclarationFragments SubHeading =
299         DeclarationFragmentsBuilder::getSubHeading(Decl);
300 
301     ObjCProtocolRecord *ObjCProtocolRecord = API.addObjCProtocol(
302         Name, USR, Loc, Availability, Comment, Declaration, SubHeading);
303 
304     recordObjCMethods(ObjCProtocolRecord, Decl->methods());
305     recordObjCProperties(ObjCProtocolRecord, Decl->properties());
306     recordObjCProtocols(ObjCProtocolRecord, Decl->protocols());
307 
308     return true;
309   }
310 
311   bool VisitTypedefNameDecl(const TypedefNameDecl *Decl) {
312     // Skip ObjC Type Parameter for now.
313     if (isa<ObjCTypeParamDecl>(Decl))
314       return true;
315 
316     if (!Decl->isDefinedOutsideFunctionOrMethod())
317       return true;
318 
319     PresumedLoc Loc =
320         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
321     StringRef Name = Decl->getName();
322     AvailabilityInfo Availability = getAvailability(Decl);
323     StringRef USR = API.recordUSR(Decl);
324     DocComment Comment;
325     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
326       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
327                                               Context.getDiagnostics());
328 
329     QualType Type = Decl->getUnderlyingType();
330     SymbolReference SymRef =
331         TypedefUnderlyingTypeResolver(Context).getSymbolReferenceForType(Type,
332                                                                          API);
333 
334     API.addTypedef(Name, USR, Loc, Availability, Comment,
335                    DeclarationFragmentsBuilder::getFragmentsForTypedef(Decl),
336                    DeclarationFragmentsBuilder::getSubHeading(Decl), SymRef);
337 
338     return true;
339   }
340 
341   bool VisitObjCCategoryDecl(const ObjCCategoryDecl *Decl) {
342     // Collect symbol information.
343     StringRef Name = Decl->getName();
344     StringRef USR = API.recordUSR(Decl);
345     PresumedLoc Loc =
346         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
347     AvailabilityInfo Availability = getAvailability(Decl);
348     DocComment Comment;
349     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
350       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
351                                               Context.getDiagnostics());
352     // Build declaration fragments and sub-heading for the category.
353     DeclarationFragments Declaration =
354         DeclarationFragmentsBuilder::getFragmentsForObjCCategory(Decl);
355     DeclarationFragments SubHeading =
356         DeclarationFragmentsBuilder::getSubHeading(Decl);
357 
358     const ObjCInterfaceDecl *InterfaceDecl = Decl->getClassInterface();
359     SymbolReference Interface(InterfaceDecl->getName(),
360                               API.recordUSR(InterfaceDecl));
361 
362     ObjCCategoryRecord *ObjCCategoryRecord =
363         API.addObjCCategory(Name, USR, Loc, Availability, Comment, Declaration,
364                             SubHeading, Interface);
365 
366     recordObjCMethods(ObjCCategoryRecord, Decl->methods());
367     recordObjCProperties(ObjCCategoryRecord, Decl->properties());
368     recordObjCInstanceVariables(ObjCCategoryRecord, Decl->ivars());
369     recordObjCProtocols(ObjCCategoryRecord, Decl->protocols());
370 
371     return true;
372   }
373 
374 private:
375   /// Get availability information of the declaration \p D.
376   AvailabilityInfo getAvailability(const Decl *D) const {
377     StringRef PlatformName = Context.getTargetInfo().getPlatformName();
378 
379     AvailabilityInfo Availability;
380     // Collect availability attributes from all redeclarations.
381     for (const auto *RD : D->redecls()) {
382       for (const auto *A : RD->specific_attrs<AvailabilityAttr>()) {
383         if (A->getPlatform()->getName() != PlatformName)
384           continue;
385         Availability = AvailabilityInfo(A->getIntroduced(), A->getDeprecated(),
386                                         A->getObsoleted(), A->getUnavailable(),
387                                         /* UnconditionallyDeprecated */ false,
388                                         /* UnconditionallyUnavailable */ false);
389         break;
390       }
391 
392       if (const auto *A = RD->getAttr<UnavailableAttr>())
393         if (!A->isImplicit()) {
394           Availability.Unavailable = true;
395           Availability.UnconditionallyUnavailable = true;
396         }
397 
398       if (const auto *A = RD->getAttr<DeprecatedAttr>())
399         if (!A->isImplicit())
400           Availability.UnconditionallyDeprecated = true;
401     }
402 
403     return Availability;
404   }
405 
406   /// Collect API information for the enum constants and associate with the
407   /// parent enum.
408   void recordEnumConstants(EnumRecord *EnumRecord,
409                            const EnumDecl::enumerator_range Constants) {
410     for (const auto *Constant : Constants) {
411       // Collect symbol information.
412       StringRef Name = Constant->getName();
413       StringRef USR = API.recordUSR(Constant);
414       PresumedLoc Loc =
415           Context.getSourceManager().getPresumedLoc(Constant->getLocation());
416       AvailabilityInfo Availability = getAvailability(Constant);
417       DocComment Comment;
418       if (auto *RawComment = Context.getRawCommentForDeclNoCache(Constant))
419         Comment = RawComment->getFormattedLines(Context.getSourceManager(),
420                                                 Context.getDiagnostics());
421 
422       // Build declaration fragments and sub-heading for the enum constant.
423       DeclarationFragments Declaration =
424           DeclarationFragmentsBuilder::getFragmentsForEnumConstant(Constant);
425       DeclarationFragments SubHeading =
426           DeclarationFragmentsBuilder::getSubHeading(Constant);
427 
428       API.addEnumConstant(EnumRecord, Name, USR, Loc, Availability, Comment,
429                           Declaration, SubHeading);
430     }
431   }
432 
433   /// Collect API information for the struct fields and associate with the
434   /// parent struct.
435   void recordStructFields(StructRecord *StructRecord,
436                           const RecordDecl::field_range Fields) {
437     for (const auto *Field : Fields) {
438       // Collect symbol information.
439       StringRef Name = Field->getName();
440       StringRef USR = API.recordUSR(Field);
441       PresumedLoc Loc =
442           Context.getSourceManager().getPresumedLoc(Field->getLocation());
443       AvailabilityInfo Availability = getAvailability(Field);
444       DocComment Comment;
445       if (auto *RawComment = Context.getRawCommentForDeclNoCache(Field))
446         Comment = RawComment->getFormattedLines(Context.getSourceManager(),
447                                                 Context.getDiagnostics());
448 
449       // Build declaration fragments and sub-heading for the struct field.
450       DeclarationFragments Declaration =
451           DeclarationFragmentsBuilder::getFragmentsForField(Field);
452       DeclarationFragments SubHeading =
453           DeclarationFragmentsBuilder::getSubHeading(Field);
454 
455       API.addStructField(StructRecord, Name, USR, Loc, Availability, Comment,
456                          Declaration, SubHeading);
457     }
458   }
459 
460   /// Collect API information for the Objective-C methods and associate with the
461   /// parent container.
462   void recordObjCMethods(ObjCContainerRecord *Container,
463                          const ObjCContainerDecl::method_range Methods) {
464     for (const auto *Method : Methods) {
465       // Don't record selectors for properties.
466       if (Method->isPropertyAccessor())
467         continue;
468 
469       StringRef Name = API.copyString(Method->getSelector().getAsString());
470       StringRef USR = API.recordUSR(Method);
471       PresumedLoc Loc =
472           Context.getSourceManager().getPresumedLoc(Method->getLocation());
473       AvailabilityInfo Availability = getAvailability(Method);
474       DocComment Comment;
475       if (auto *RawComment = Context.getRawCommentForDeclNoCache(Method))
476         Comment = RawComment->getFormattedLines(Context.getSourceManager(),
477                                                 Context.getDiagnostics());
478 
479       // Build declaration fragments, sub-heading, and signature for the method.
480       DeclarationFragments Declaration =
481           DeclarationFragmentsBuilder::getFragmentsForObjCMethod(Method);
482       DeclarationFragments SubHeading =
483           DeclarationFragmentsBuilder::getSubHeading(Method);
484       FunctionSignature Signature =
485           DeclarationFragmentsBuilder::getFunctionSignature(Method);
486 
487       API.addObjCMethod(Container, Name, USR, Loc, Availability, Comment,
488                         Declaration, SubHeading, Signature,
489                         Method->isInstanceMethod());
490     }
491   }
492 
493   void recordObjCProperties(ObjCContainerRecord *Container,
494                             const ObjCContainerDecl::prop_range Properties) {
495     for (const auto *Property : Properties) {
496       StringRef Name = Property->getName();
497       StringRef USR = API.recordUSR(Property);
498       PresumedLoc Loc =
499           Context.getSourceManager().getPresumedLoc(Property->getLocation());
500       AvailabilityInfo Availability = getAvailability(Property);
501       DocComment Comment;
502       if (auto *RawComment = Context.getRawCommentForDeclNoCache(Property))
503         Comment = RawComment->getFormattedLines(Context.getSourceManager(),
504                                                 Context.getDiagnostics());
505 
506       // Build declaration fragments and sub-heading for the property.
507       DeclarationFragments Declaration =
508           DeclarationFragmentsBuilder::getFragmentsForObjCProperty(Property);
509       DeclarationFragments SubHeading =
510           DeclarationFragmentsBuilder::getSubHeading(Property);
511 
512       StringRef GetterName =
513           API.copyString(Property->getGetterName().getAsString());
514       StringRef SetterName =
515           API.copyString(Property->getSetterName().getAsString());
516 
517       // Get the attributes for property.
518       unsigned Attributes = ObjCPropertyRecord::NoAttr;
519       if (Property->getPropertyAttributes() &
520           ObjCPropertyAttribute::kind_readonly)
521         Attributes |= ObjCPropertyRecord::ReadOnly;
522       if (Property->getPropertyAttributes() & ObjCPropertyAttribute::kind_class)
523         Attributes |= ObjCPropertyRecord::Class;
524 
525       API.addObjCProperty(
526           Container, Name, USR, Loc, Availability, Comment, Declaration,
527           SubHeading,
528           static_cast<ObjCPropertyRecord::AttributeKind>(Attributes),
529           GetterName, SetterName, Property->isOptional());
530     }
531   }
532 
533   void recordObjCInstanceVariables(
534       ObjCContainerRecord *Container,
535       const llvm::iterator_range<
536           DeclContext::specific_decl_iterator<ObjCIvarDecl>>
537           Ivars) {
538     for (const auto *Ivar : Ivars) {
539       StringRef Name = Ivar->getName();
540       StringRef USR = API.recordUSR(Ivar);
541       PresumedLoc Loc =
542           Context.getSourceManager().getPresumedLoc(Ivar->getLocation());
543       AvailabilityInfo Availability = getAvailability(Ivar);
544       DocComment Comment;
545       if (auto *RawComment = Context.getRawCommentForDeclNoCache(Ivar))
546         Comment = RawComment->getFormattedLines(Context.getSourceManager(),
547                                                 Context.getDiagnostics());
548 
549       // Build declaration fragments and sub-heading for the instance variable.
550       DeclarationFragments Declaration =
551           DeclarationFragmentsBuilder::getFragmentsForField(Ivar);
552       DeclarationFragments SubHeading =
553           DeclarationFragmentsBuilder::getSubHeading(Ivar);
554 
555       ObjCInstanceVariableRecord::AccessControl Access =
556           Ivar->getCanonicalAccessControl();
557 
558       API.addObjCInstanceVariable(Container, Name, USR, Loc, Availability,
559                                   Comment, Declaration, SubHeading, Access);
560     }
561   }
562 
563   void recordObjCProtocols(ObjCContainerRecord *Container,
564                            ObjCInterfaceDecl::protocol_range Protocols) {
565     for (const auto *Protocol : Protocols)
566       Container->Protocols.emplace_back(Protocol->getName(),
567                                         API.recordUSR(Protocol));
568   }
569 
570   ASTContext &Context;
571   APISet &API;
572 };
573 
574 class ExtractAPIConsumer : public ASTConsumer {
575 public:
576   ExtractAPIConsumer(ASTContext &Context, APISet &API)
577       : Visitor(Context, API) {}
578 
579   void HandleTranslationUnit(ASTContext &Context) override {
580     // Use ExtractAPIVisitor to traverse symbol declarations in the context.
581     Visitor.TraverseDecl(Context.getTranslationUnitDecl());
582   }
583 
584 private:
585   ExtractAPIVisitor Visitor;
586 };
587 
588 class MacroCallback : public PPCallbacks {
589 public:
590   MacroCallback(const SourceManager &SM, APISet &API) : SM(SM), API(API) {}
591 
592   void MacroDefined(const Token &MacroNameToken,
593                     const MacroDirective *MD) override {
594     auto *MacroInfo = MD->getMacroInfo();
595 
596     if (MacroInfo->isBuiltinMacro())
597       return;
598 
599     auto SourceLoc = MacroNameToken.getLocation();
600     if (SM.isWrittenInBuiltinFile(SourceLoc) ||
601         SM.isWrittenInCommandLineFile(SourceLoc))
602       return;
603 
604     PendingMacros.emplace_back(MacroNameToken, MD);
605   }
606 
607   // If a macro gets undefined at some point during preprocessing of the inputs
608   // it means that it isn't an exposed API and we should therefore not add a
609   // macro definition for it.
610   void MacroUndefined(const Token &MacroNameToken, const MacroDefinition &MD,
611                       const MacroDirective *Undef) override {
612     // If this macro wasn't previously defined we don't need to do anything
613     // here.
614     if (!Undef)
615       return;
616 
617     llvm::erase_if(PendingMacros, [&MD](const PendingMacro &PM) {
618       return MD.getMacroInfo()->getDefinitionLoc() ==
619              PM.MD->getMacroInfo()->getDefinitionLoc();
620     });
621   }
622 
623   void EndOfMainFile() override {
624     for (auto &PM : PendingMacros) {
625       // `isUsedForHeaderGuard` is only set when the preprocessor leaves the
626       // file so check for it here.
627       if (PM.MD->getMacroInfo()->isUsedForHeaderGuard())
628         continue;
629 
630       StringRef Name = PM.MacroNameToken.getIdentifierInfo()->getName();
631       PresumedLoc Loc = SM.getPresumedLoc(PM.MacroNameToken.getLocation());
632       StringRef USR =
633           API.recordUSRForMacro(Name, PM.MacroNameToken.getLocation(), SM);
634 
635       API.addMacroDefinition(
636           Name, USR, Loc,
637           DeclarationFragmentsBuilder::getFragmentsForMacro(Name, PM.MD),
638           DeclarationFragmentsBuilder::getSubHeadingForMacro(Name));
639     }
640 
641     PendingMacros.clear();
642   }
643 
644 private:
645   struct PendingMacro {
646     Token MacroNameToken;
647     const MacroDirective *MD;
648 
649     PendingMacro(const Token &MacroNameToken, const MacroDirective *MD)
650         : MacroNameToken(MacroNameToken), MD(MD) {}
651   };
652 
653   const SourceManager &SM;
654   APISet &API;
655   llvm::SmallVector<PendingMacro> PendingMacros;
656 };
657 
658 } // namespace
659 
660 std::unique_ptr<ASTConsumer>
661 ExtractAPIAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
662   OS = CreateOutputFile(CI, InFile);
663   if (!OS)
664     return nullptr;
665 
666   ProductName = CI.getFrontendOpts().ProductName;
667 
668   // Now that we have enough information about the language options and the
669   // target triple, let's create the APISet before anyone uses it.
670   API = std::make_unique<APISet>(
671       CI.getTarget().getTriple(),
672       CI.getFrontendOpts().Inputs.back().getKind().getLanguage());
673 
674   // Register preprocessor callbacks that will add macro definitions to API.
675   CI.getPreprocessor().addPPCallbacks(
676       std::make_unique<MacroCallback>(CI.getSourceManager(), *API));
677 
678   return std::make_unique<ExtractAPIConsumer>(CI.getASTContext(), *API);
679 }
680 
681 bool ExtractAPIAction::PrepareToExecuteAction(CompilerInstance &CI) {
682   auto &Inputs = CI.getFrontendOpts().Inputs;
683   if (Inputs.empty())
684     return true;
685 
686   auto Kind = Inputs[0].getKind();
687 
688   // Convert the header file inputs into a single input buffer.
689   SmallString<256> HeaderContents;
690   for (const FrontendInputFile &FIF : Inputs) {
691     if (Kind.isObjectiveC())
692       HeaderContents += "#import";
693     else
694       HeaderContents += "#include";
695     HeaderContents += " \"";
696     HeaderContents += FIF.getFile();
697     HeaderContents += "\"\n";
698   }
699 
700   Buffer = llvm::MemoryBuffer::getMemBufferCopy(HeaderContents,
701                                                 getInputBufferName());
702 
703   // Set that buffer up as our "real" input in the CompilerInstance.
704   Inputs.clear();
705   Inputs.emplace_back(Buffer->getMemBufferRef(), Kind, /*IsSystem*/ false);
706 
707   return true;
708 }
709 
710 void ExtractAPIAction::EndSourceFileAction() {
711   if (!OS)
712     return;
713 
714   // Setup a SymbolGraphSerializer to write out collected API information in
715   // the Symbol Graph format.
716   // FIXME: Make the kind of APISerializer configurable.
717   SymbolGraphSerializer SGSerializer(*API, ProductName);
718   SGSerializer.serialize(*OS);
719   OS->flush();
720 }
721 
722 std::unique_ptr<raw_pwrite_stream>
723 ExtractAPIAction::CreateOutputFile(CompilerInstance &CI, StringRef InFile) {
724   std::unique_ptr<raw_pwrite_stream> OS =
725       CI.createDefaultOutputFile(/*Binary=*/false, InFile, /*Extension=*/"json",
726                                  /*RemoveFileOnSignal=*/false);
727   if (!OS)
728     return nullptr;
729   return OS;
730 }
731