1 //===- ExtractAPI/ExtractAPIConsumer.cpp ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the ExtractAPIAction, and ASTVisitor/Consumer to
11 /// collect API information.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "TypedefUnderlyingTypeResolver.h"
16 #include "clang/AST/ASTConsumer.h"
17 #include "clang/AST/ASTContext.h"
18 #include "clang/AST/Decl.h"
19 #include "clang/AST/DeclCXX.h"
20 #include "clang/AST/ParentMapContext.h"
21 #include "clang/AST/RawCommentList.h"
22 #include "clang/AST/RecursiveASTVisitor.h"
23 #include "clang/Basic/TargetInfo.h"
24 #include "clang/ExtractAPI/API.h"
25 #include "clang/ExtractAPI/AvailabilityInfo.h"
26 #include "clang/ExtractAPI/DeclarationFragments.h"
27 #include "clang/ExtractAPI/FrontendActions.h"
28 #include "clang/ExtractAPI/Serialization/SymbolGraphSerializer.h"
29 #include "clang/Frontend/ASTConsumers.h"
30 #include "clang/Frontend/CompilerInstance.h"
31 #include "clang/Frontend/FrontendOptions.h"
32 #include "clang/Lex/MacroInfo.h"
33 #include "clang/Lex/PPCallbacks.h"
34 #include "clang/Lex/PreprocessorOptions.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/SmallVector.h"
37 #include "llvm/Support/MemoryBuffer.h"
38 #include "llvm/Support/raw_ostream.h"
39 
40 using namespace clang;
41 using namespace extractapi;
42 
43 namespace {
44 
45 StringRef getTypedefName(const TagDecl *Decl) {
46   if (const auto *TypedefDecl = Decl->getTypedefNameForAnonDecl())
47     return TypedefDecl->getName();
48 
49   return {};
50 }
51 
52 /// The RecursiveASTVisitor to traverse symbol declarations and collect API
53 /// information.
54 class ExtractAPIVisitor : public RecursiveASTVisitor<ExtractAPIVisitor> {
55 public:
  /// Construct a visitor that reads declarations through \p Context and adds
  /// the records it builds to \p API. Both are held by reference.
  ExtractAPIVisitor(ASTContext &Context, APISet &API)
      : Context(Context), API(API) {}
58 
  /// Return a read-only reference to the API set this visitor adds records to.
  const APISet &getAPI() const { return API; }
60 
61   bool VisitVarDecl(const VarDecl *Decl) {
62     // Skip function parameters.
63     if (isa<ParmVarDecl>(Decl))
64       return true;
65 
66     // Skip non-global variables in records (struct/union/class).
67     if (Decl->getDeclContext()->isRecord())
68       return true;
69 
70     // Skip local variables inside function or method.
71     if (!Decl->isDefinedOutsideFunctionOrMethod())
72       return true;
73 
74     // If this is a template but not specialization or instantiation, skip.
75     if (Decl->getASTContext().getTemplateOrSpecializationInfo(Decl) &&
76         Decl->getTemplateSpecializationKind() == TSK_Undeclared)
77       return true;
78 
79     // Collect symbol information.
80     StringRef Name = Decl->getName();
81     StringRef USR = API.recordUSR(Decl);
82     PresumedLoc Loc =
83         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
84     AvailabilityInfo Availability = getAvailability(Decl);
85     LinkageInfo Linkage = Decl->getLinkageAndVisibility();
86     DocComment Comment;
87     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
88       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
89                                               Context.getDiagnostics());
90 
91     // Build declaration fragments and sub-heading for the variable.
92     DeclarationFragments Declaration =
93         DeclarationFragmentsBuilder::getFragmentsForVar(Decl);
94     DeclarationFragments SubHeading =
95         DeclarationFragmentsBuilder::getSubHeading(Decl);
96 
97     // Add the global variable record to the API set.
98     API.addGlobalVar(Name, USR, Loc, Availability, Linkage, Comment,
99                      Declaration, SubHeading);
100     return true;
101   }
102 
103   bool VisitFunctionDecl(const FunctionDecl *Decl) {
104     if (const auto *Method = dyn_cast<CXXMethodDecl>(Decl)) {
105       // Skip member function in class templates.
106       if (Method->getParent()->getDescribedClassTemplate() != nullptr)
107         return true;
108 
109       // Skip methods in records.
110       for (auto P : Context.getParents(*Method)) {
111         if (P.get<CXXRecordDecl>())
112           return true;
113       }
114 
115       // Skip ConstructorDecl and DestructorDecl.
116       if (isa<CXXConstructorDecl>(Method) || isa<CXXDestructorDecl>(Method))
117         return true;
118     }
119 
120     // Skip templated functions.
121     switch (Decl->getTemplatedKind()) {
122     case FunctionDecl::TK_NonTemplate:
123       break;
124     case FunctionDecl::TK_MemberSpecialization:
125     case FunctionDecl::TK_FunctionTemplateSpecialization:
126       if (auto *TemplateInfo = Decl->getTemplateSpecializationInfo()) {
127         if (!TemplateInfo->isExplicitInstantiationOrSpecialization())
128           return true;
129       }
130       break;
131     case FunctionDecl::TK_FunctionTemplate:
132     case FunctionDecl::TK_DependentFunctionTemplateSpecialization:
133       return true;
134     }
135 
136     // Collect symbol information.
137     StringRef Name = Decl->getName();
138     StringRef USR = API.recordUSR(Decl);
139     PresumedLoc Loc =
140         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
141     AvailabilityInfo Availability = getAvailability(Decl);
142     LinkageInfo Linkage = Decl->getLinkageAndVisibility();
143     DocComment Comment;
144     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
145       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
146                                               Context.getDiagnostics());
147 
148     // Build declaration fragments, sub-heading, and signature of the function.
149     DeclarationFragments Declaration =
150         DeclarationFragmentsBuilder::getFragmentsForFunction(Decl);
151     DeclarationFragments SubHeading =
152         DeclarationFragmentsBuilder::getSubHeading(Decl);
153     FunctionSignature Signature =
154         DeclarationFragmentsBuilder::getFunctionSignature(Decl);
155 
156     // Add the function record to the API set.
157     API.addFunction(Name, USR, Loc, Availability, Linkage, Comment, Declaration,
158                     SubHeading, Signature);
159     return true;
160   }
161 
162   bool VisitEnumDecl(const EnumDecl *Decl) {
163     if (!Decl->isComplete())
164       return true;
165 
166     // Skip forward declaration.
167     if (!Decl->isThisDeclarationADefinition())
168       return true;
169 
170     // Collect symbol information.
171     StringRef Name = Decl->getName();
172     if (Name.empty())
173       Name = getTypedefName(Decl);
174     StringRef USR = API.recordUSR(Decl);
175     PresumedLoc Loc =
176         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
177     AvailabilityInfo Availability = getAvailability(Decl);
178     DocComment Comment;
179     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
180       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
181                                               Context.getDiagnostics());
182 
183     // Build declaration fragments and sub-heading for the enum.
184     DeclarationFragments Declaration =
185         DeclarationFragmentsBuilder::getFragmentsForEnum(Decl);
186     DeclarationFragments SubHeading =
187         DeclarationFragmentsBuilder::getSubHeading(Decl);
188 
189     EnumRecord *EnumRecord = API.addEnum(Name, USR, Loc, Availability, Comment,
190                                          Declaration, SubHeading);
191 
192     // Now collect information about the enumerators in this enum.
193     recordEnumConstants(EnumRecord, Decl->enumerators());
194 
195     return true;
196   }
197 
198   bool VisitRecordDecl(const RecordDecl *Decl) {
199     if (!Decl->isCompleteDefinition())
200       return true;
201 
202     // Skip C++ structs/classes/unions
203     // TODO: support C++ records
204     if (isa<CXXRecordDecl>(Decl))
205       return true;
206 
207     // Collect symbol information.
208     StringRef Name = Decl->getName();
209     if (Name.empty())
210       Name = getTypedefName(Decl);
211     StringRef USR = API.recordUSR(Decl);
212     PresumedLoc Loc =
213         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
214     AvailabilityInfo Availability = getAvailability(Decl);
215     DocComment Comment;
216     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
217       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
218                                               Context.getDiagnostics());
219 
220     // Build declaration fragments and sub-heading for the struct.
221     DeclarationFragments Declaration =
222         DeclarationFragmentsBuilder::getFragmentsForStruct(Decl);
223     DeclarationFragments SubHeading =
224         DeclarationFragmentsBuilder::getSubHeading(Decl);
225 
226     StructRecord *StructRecord = API.addStruct(
227         Name, USR, Loc, Availability, Comment, Declaration, SubHeading);
228 
229     // Now collect information about the fields in this struct.
230     recordStructFields(StructRecord, Decl->fields());
231 
232     return true;
233   }
234 
235   bool VisitObjCInterfaceDecl(const ObjCInterfaceDecl *Decl) {
236     // Skip forward declaration for classes (@class)
237     if (!Decl->isThisDeclarationADefinition())
238       return true;
239 
240     // Collect symbol information.
241     StringRef Name = Decl->getName();
242     StringRef USR = API.recordUSR(Decl);
243     PresumedLoc Loc =
244         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
245     AvailabilityInfo Availability = getAvailability(Decl);
246     LinkageInfo Linkage = Decl->getLinkageAndVisibility();
247     DocComment Comment;
248     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
249       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
250                                               Context.getDiagnostics());
251 
252     // Build declaration fragments and sub-heading for the interface.
253     DeclarationFragments Declaration =
254         DeclarationFragmentsBuilder::getFragmentsForObjCInterface(Decl);
255     DeclarationFragments SubHeading =
256         DeclarationFragmentsBuilder::getSubHeading(Decl);
257 
258     // Collect super class information.
259     SymbolReference SuperClass;
260     if (const auto *SuperClassDecl = Decl->getSuperClass()) {
261       SuperClass.Name = SuperClassDecl->getObjCRuntimeNameAsString();
262       SuperClass.USR = API.recordUSR(SuperClassDecl);
263     }
264 
265     ObjCInterfaceRecord *ObjCInterfaceRecord =
266         API.addObjCInterface(Name, USR, Loc, Availability, Linkage, Comment,
267                              Declaration, SubHeading, SuperClass);
268 
269     // Record all methods (selectors). This doesn't include automatically
270     // synthesized property methods.
271     recordObjCMethods(ObjCInterfaceRecord, Decl->methods());
272     recordObjCProperties(ObjCInterfaceRecord, Decl->properties());
273     recordObjCInstanceVariables(ObjCInterfaceRecord, Decl->ivars());
274     recordObjCProtocols(ObjCInterfaceRecord, Decl->protocols());
275 
276     return true;
277   }
278 
279   bool VisitObjCProtocolDecl(const ObjCProtocolDecl *Decl) {
280     // Skip forward declaration for protocols (@protocol).
281     if (!Decl->isThisDeclarationADefinition())
282       return true;
283 
284     // Collect symbol information.
285     StringRef Name = Decl->getName();
286     StringRef USR = API.recordUSR(Decl);
287     PresumedLoc Loc =
288         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
289     AvailabilityInfo Availability = getAvailability(Decl);
290     DocComment Comment;
291     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
292       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
293                                               Context.getDiagnostics());
294 
295     // Build declaration fragments and sub-heading for the protocol.
296     DeclarationFragments Declaration =
297         DeclarationFragmentsBuilder::getFragmentsForObjCProtocol(Decl);
298     DeclarationFragments SubHeading =
299         DeclarationFragmentsBuilder::getSubHeading(Decl);
300 
301     ObjCProtocolRecord *ObjCProtocolRecord = API.addObjCProtocol(
302         Name, USR, Loc, Availability, Comment, Declaration, SubHeading);
303 
304     recordObjCMethods(ObjCProtocolRecord, Decl->methods());
305     recordObjCProperties(ObjCProtocolRecord, Decl->properties());
306     recordObjCProtocols(ObjCProtocolRecord, Decl->protocols());
307 
308     return true;
309   }
310 
311   bool VisitTypedefNameDecl(const TypedefNameDecl *Decl) {
312     // Skip ObjC Type Parameter for now.
313     if (isa<ObjCTypeParamDecl>(Decl))
314       return true;
315 
316     if (!Decl->isDefinedOutsideFunctionOrMethod())
317       return true;
318 
319     PresumedLoc Loc =
320         Context.getSourceManager().getPresumedLoc(Decl->getLocation());
321     StringRef Name = Decl->getName();
322     AvailabilityInfo Availability = getAvailability(Decl);
323     StringRef USR = API.recordUSR(Decl);
324     DocComment Comment;
325     if (auto *RawComment = Context.getRawCommentForDeclNoCache(Decl))
326       Comment = RawComment->getFormattedLines(Context.getSourceManager(),
327                                               Context.getDiagnostics());
328 
329     QualType Type = Decl->getUnderlyingType();
330     SymbolReference SymRef =
331         TypedefUnderlyingTypeResolver(Context).getSymbolReferenceForType(Type,
332                                                                          API);
333 
334     API.addTypedef(Name, USR, Loc, Availability, Comment,
335                    DeclarationFragmentsBuilder::getFragmentsForTypedef(Decl),
336                    DeclarationFragmentsBuilder::getSubHeading(Decl), SymRef);
337 
338     return true;
339   }
340 
341 private:
342   /// Get availability information of the declaration \p D.
343   AvailabilityInfo getAvailability(const Decl *D) const {
344     StringRef PlatformName = Context.getTargetInfo().getPlatformName();
345 
346     AvailabilityInfo Availability;
347     // Collect availability attributes from all redeclarations.
348     for (const auto *RD : D->redecls()) {
349       for (const auto *A : RD->specific_attrs<AvailabilityAttr>()) {
350         if (A->getPlatform()->getName() != PlatformName)
351           continue;
352         Availability = AvailabilityInfo(A->getIntroduced(), A->getDeprecated(),
353                                         A->getObsoleted(), A->getUnavailable(),
354                                         /* UnconditionallyDeprecated */ false,
355                                         /* UnconditionallyUnavailable */ false);
356         break;
357       }
358 
359       if (const auto *A = RD->getAttr<UnavailableAttr>())
360         if (!A->isImplicit()) {
361           Availability.Unavailable = true;
362           Availability.UnconditionallyUnavailable = true;
363         }
364 
365       if (const auto *A = RD->getAttr<DeprecatedAttr>())
366         if (!A->isImplicit())
367           Availability.UnconditionallyDeprecated = true;
368     }
369 
370     return Availability;
371   }
372 
373   /// Collect API information for the enum constants and associate with the
374   /// parent enum.
375   void recordEnumConstants(EnumRecord *EnumRecord,
376                            const EnumDecl::enumerator_range Constants) {
377     for (const auto *Constant : Constants) {
378       // Collect symbol information.
379       StringRef Name = Constant->getName();
380       StringRef USR = API.recordUSR(Constant);
381       PresumedLoc Loc =
382           Context.getSourceManager().getPresumedLoc(Constant->getLocation());
383       AvailabilityInfo Availability = getAvailability(Constant);
384       DocComment Comment;
385       if (auto *RawComment = Context.getRawCommentForDeclNoCache(Constant))
386         Comment = RawComment->getFormattedLines(Context.getSourceManager(),
387                                                 Context.getDiagnostics());
388 
389       // Build declaration fragments and sub-heading for the enum constant.
390       DeclarationFragments Declaration =
391           DeclarationFragmentsBuilder::getFragmentsForEnumConstant(Constant);
392       DeclarationFragments SubHeading =
393           DeclarationFragmentsBuilder::getSubHeading(Constant);
394 
395       API.addEnumConstant(EnumRecord, Name, USR, Loc, Availability, Comment,
396                           Declaration, SubHeading);
397     }
398   }
399 
400   /// Collect API information for the struct fields and associate with the
401   /// parent struct.
402   void recordStructFields(StructRecord *StructRecord,
403                           const RecordDecl::field_range Fields) {
404     for (const auto *Field : Fields) {
405       // Collect symbol information.
406       StringRef Name = Field->getName();
407       StringRef USR = API.recordUSR(Field);
408       PresumedLoc Loc =
409           Context.getSourceManager().getPresumedLoc(Field->getLocation());
410       AvailabilityInfo Availability = getAvailability(Field);
411       DocComment Comment;
412       if (auto *RawComment = Context.getRawCommentForDeclNoCache(Field))
413         Comment = RawComment->getFormattedLines(Context.getSourceManager(),
414                                                 Context.getDiagnostics());
415 
416       // Build declaration fragments and sub-heading for the struct field.
417       DeclarationFragments Declaration =
418           DeclarationFragmentsBuilder::getFragmentsForField(Field);
419       DeclarationFragments SubHeading =
420           DeclarationFragmentsBuilder::getSubHeading(Field);
421 
422       API.addStructField(StructRecord, Name, USR, Loc, Availability, Comment,
423                          Declaration, SubHeading);
424     }
425   }
426 
427   /// Collect API information for the Objective-C methods and associate with the
428   /// parent container.
429   void recordObjCMethods(ObjCContainerRecord *Container,
430                          const ObjCContainerDecl::method_range Methods) {
431     for (const auto *Method : Methods) {
432       // Don't record selectors for properties.
433       if (Method->isPropertyAccessor())
434         continue;
435 
436       StringRef Name = API.copyString(Method->getSelector().getAsString());
437       StringRef USR = API.recordUSR(Method);
438       PresumedLoc Loc =
439           Context.getSourceManager().getPresumedLoc(Method->getLocation());
440       AvailabilityInfo Availability = getAvailability(Method);
441       DocComment Comment;
442       if (auto *RawComment = Context.getRawCommentForDeclNoCache(Method))
443         Comment = RawComment->getFormattedLines(Context.getSourceManager(),
444                                                 Context.getDiagnostics());
445 
446       // Build declaration fragments, sub-heading, and signature for the method.
447       DeclarationFragments Declaration =
448           DeclarationFragmentsBuilder::getFragmentsForObjCMethod(Method);
449       DeclarationFragments SubHeading =
450           DeclarationFragmentsBuilder::getSubHeading(Method);
451       FunctionSignature Signature =
452           DeclarationFragmentsBuilder::getFunctionSignature(Method);
453 
454       API.addObjCMethod(Container, Name, USR, Loc, Availability, Comment,
455                         Declaration, SubHeading, Signature,
456                         Method->isInstanceMethod());
457     }
458   }
459 
460   void recordObjCProperties(ObjCContainerRecord *Container,
461                             const ObjCContainerDecl::prop_range Properties) {
462     for (const auto *Property : Properties) {
463       StringRef Name = Property->getName();
464       StringRef USR = API.recordUSR(Property);
465       PresumedLoc Loc =
466           Context.getSourceManager().getPresumedLoc(Property->getLocation());
467       AvailabilityInfo Availability = getAvailability(Property);
468       DocComment Comment;
469       if (auto *RawComment = Context.getRawCommentForDeclNoCache(Property))
470         Comment = RawComment->getFormattedLines(Context.getSourceManager(),
471                                                 Context.getDiagnostics());
472 
473       // Build declaration fragments and sub-heading for the property.
474       DeclarationFragments Declaration =
475           DeclarationFragmentsBuilder::getFragmentsForObjCProperty(Property);
476       DeclarationFragments SubHeading =
477           DeclarationFragmentsBuilder::getSubHeading(Property);
478 
479       StringRef GetterName =
480           API.copyString(Property->getGetterName().getAsString());
481       StringRef SetterName =
482           API.copyString(Property->getSetterName().getAsString());
483 
484       // Get the attributes for property.
485       unsigned Attributes = ObjCPropertyRecord::NoAttr;
486       if (Property->getPropertyAttributes() &
487           ObjCPropertyAttribute::kind_readonly)
488         Attributes |= ObjCPropertyRecord::ReadOnly;
489       if (Property->getPropertyAttributes() & ObjCPropertyAttribute::kind_class)
490         Attributes |= ObjCPropertyRecord::Class;
491 
492       API.addObjCProperty(
493           Container, Name, USR, Loc, Availability, Comment, Declaration,
494           SubHeading,
495           static_cast<ObjCPropertyRecord::AttributeKind>(Attributes),
496           GetterName, SetterName, Property->isOptional());
497     }
498   }
499 
500   void recordObjCInstanceVariables(
501       ObjCContainerRecord *Container,
502       const llvm::iterator_range<
503           DeclContext::specific_decl_iterator<ObjCIvarDecl>>
504           Ivars) {
505     for (const auto *Ivar : Ivars) {
506       StringRef Name = Ivar->getName();
507       StringRef USR = API.recordUSR(Ivar);
508       PresumedLoc Loc =
509           Context.getSourceManager().getPresumedLoc(Ivar->getLocation());
510       AvailabilityInfo Availability = getAvailability(Ivar);
511       DocComment Comment;
512       if (auto *RawComment = Context.getRawCommentForDeclNoCache(Ivar))
513         Comment = RawComment->getFormattedLines(Context.getSourceManager(),
514                                                 Context.getDiagnostics());
515 
516       // Build declaration fragments and sub-heading for the instance variable.
517       DeclarationFragments Declaration =
518           DeclarationFragmentsBuilder::getFragmentsForField(Ivar);
519       DeclarationFragments SubHeading =
520           DeclarationFragmentsBuilder::getSubHeading(Ivar);
521 
522       ObjCInstanceVariableRecord::AccessControl Access =
523           Ivar->getCanonicalAccessControl();
524 
525       API.addObjCInstanceVariable(Container, Name, USR, Loc, Availability,
526                                   Comment, Declaration, SubHeading, Access);
527     }
528   }
529 
530   void recordObjCProtocols(ObjCContainerRecord *Container,
531                            ObjCInterfaceDecl::protocol_range Protocols) {
532     for (const auto *Protocol : Protocols)
533       Container->Protocols.emplace_back(Protocol->getName(),
534                                         API.recordUSR(Protocol));
535   }
536 
537   ASTContext &Context;
538   APISet &API;
539 };
540 
541 class ExtractAPIConsumer : public ASTConsumer {
542 public:
543   ExtractAPIConsumer(ASTContext &Context, APISet &API)
544       : Visitor(Context, API) {}
545 
546   void HandleTranslationUnit(ASTContext &Context) override {
547     // Use ExtractAPIVisitor to traverse symbol declarations in the context.
548     Visitor.TraverseDecl(Context.getTranslationUnitDecl());
549   }
550 
551 private:
552   ExtractAPIVisitor Visitor;
553 };
554 
555 class MacroCallback : public PPCallbacks {
556 public:
557   MacroCallback(const SourceManager &SM, APISet &API) : SM(SM), API(API) {}
558 
559   void MacroDefined(const Token &MacroNameToken,
560                     const MacroDirective *MD) override {
561     auto *MacroInfo = MD->getMacroInfo();
562 
563     if (MacroInfo->isBuiltinMacro())
564       return;
565 
566     auto SourceLoc = MacroNameToken.getLocation();
567     if (SM.isWrittenInBuiltinFile(SourceLoc) ||
568         SM.isWrittenInCommandLineFile(SourceLoc))
569       return;
570 
571     PendingMacros.emplace_back(MacroNameToken, MD);
572   }
573 
574   // If a macro gets undefined at some point during preprocessing of the inputs
575   // it means that it isn't an exposed API and we should therefore not add a
576   // macro definition for it.
577   void MacroUndefined(const Token &MacroNameToken, const MacroDefinition &MD,
578                       const MacroDirective *Undef) override {
579     // If this macro wasn't previously defined we don't need to do anything
580     // here.
581     if (!Undef)
582       return;
583 
584     llvm::erase_if(PendingMacros, [&MD](const PendingMacro &PM) {
585       return MD.getMacroInfo()->getDefinitionLoc() ==
586              PM.MD->getMacroInfo()->getDefinitionLoc();
587     });
588   }
589 
590   void EndOfMainFile() override {
591     for (auto &PM : PendingMacros) {
592       // `isUsedForHeaderGuard` is only set when the preprocessor leaves the
593       // file so check for it here.
594       if (PM.MD->getMacroInfo()->isUsedForHeaderGuard())
595         continue;
596 
597       StringRef Name = PM.MacroNameToken.getIdentifierInfo()->getName();
598       PresumedLoc Loc = SM.getPresumedLoc(PM.MacroNameToken.getLocation());
599       StringRef USR =
600           API.recordUSRForMacro(Name, PM.MacroNameToken.getLocation(), SM);
601 
602       API.addMacroDefinition(
603           Name, USR, Loc,
604           DeclarationFragmentsBuilder::getFragmentsForMacro(Name, PM.MD),
605           DeclarationFragmentsBuilder::getSubHeadingForMacro(Name));
606     }
607 
608     PendingMacros.clear();
609   }
610 
611 private:
612   struct PendingMacro {
613     Token MacroNameToken;
614     const MacroDirective *MD;
615 
616     PendingMacro(const Token &MacroNameToken, const MacroDirective *MD)
617         : MacroNameToken(MacroNameToken), MD(MD) {}
618   };
619 
620   const SourceManager &SM;
621   APISet &API;
622   llvm::SmallVector<PendingMacro> PendingMacros;
623 };
624 
625 } // namespace
626 
627 std::unique_ptr<ASTConsumer>
628 ExtractAPIAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
629   OS = CreateOutputFile(CI, InFile);
630   if (!OS)
631     return nullptr;
632 
633   ProductName = CI.getFrontendOpts().ProductName;
634 
635   // Now that we have enough information about the language options and the
636   // target triple, let's create the APISet before anyone uses it.
637   API = std::make_unique<APISet>(
638       CI.getTarget().getTriple(),
639       CI.getFrontendOpts().Inputs.back().getKind().getLanguage());
640 
641   // Register preprocessor callbacks that will add macro definitions to API.
642   CI.getPreprocessor().addPPCallbacks(
643       std::make_unique<MacroCallback>(CI.getSourceManager(), *API));
644 
645   return std::make_unique<ExtractAPIConsumer>(CI.getASTContext(), *API);
646 }
647 
648 bool ExtractAPIAction::PrepareToExecuteAction(CompilerInstance &CI) {
649   auto &Inputs = CI.getFrontendOpts().Inputs;
650   if (Inputs.empty())
651     return true;
652 
653   auto Kind = Inputs[0].getKind();
654 
655   // Convert the header file inputs into a single input buffer.
656   SmallString<256> HeaderContents;
657   for (const FrontendInputFile &FIF : Inputs) {
658     if (Kind.isObjectiveC())
659       HeaderContents += "#import";
660     else
661       HeaderContents += "#include";
662     HeaderContents += " \"";
663     HeaderContents += FIF.getFile();
664     HeaderContents += "\"\n";
665   }
666 
667   Buffer = llvm::MemoryBuffer::getMemBufferCopy(HeaderContents,
668                                                 getInputBufferName());
669 
670   // Set that buffer up as our "real" input in the CompilerInstance.
671   Inputs.clear();
672   Inputs.emplace_back(Buffer->getMemBufferRef(), Kind, /*IsSystem*/ false);
673 
674   return true;
675 }
676 
677 void ExtractAPIAction::EndSourceFileAction() {
678   if (!OS)
679     return;
680 
681   // Setup a SymbolGraphSerializer to write out collected API information in
682   // the Symbol Graph format.
683   // FIXME: Make the kind of APISerializer configurable.
684   SymbolGraphSerializer SGSerializer(*API, ProductName);
685   SGSerializer.serialize(*OS);
686   OS->flush();
687 }
688 
689 std::unique_ptr<raw_pwrite_stream>
690 ExtractAPIAction::CreateOutputFile(CompilerInstance &CI, StringRef InFile) {
691   std::unique_ptr<raw_pwrite_stream> OS =
692       CI.createDefaultOutputFile(/*Binary=*/false, InFile, /*Extension=*/"json",
693                                  /*RemoveFileOnSignal=*/false);
694   if (!OS)
695     return nullptr;
696   return OS;
697 }
698