142418abaSMehdi Amini //===- FunctionImport.cpp - ThinLTO Summary-based Function Import ---------===//
242418abaSMehdi Amini //
342418abaSMehdi Amini //                     The LLVM Compiler Infrastructure
442418abaSMehdi Amini //
542418abaSMehdi Amini // This file is distributed under the University of Illinois Open Source
642418abaSMehdi Amini // License. See LICENSE.TXT for details.
742418abaSMehdi Amini //
842418abaSMehdi Amini //===----------------------------------------------------------------------===//
942418abaSMehdi Amini //
1042418abaSMehdi Amini // This file implements Function import based on summaries.
1142418abaSMehdi Amini //
1242418abaSMehdi Amini //===----------------------------------------------------------------------===//
1342418abaSMehdi Amini 
1442418abaSMehdi Amini #include "llvm/Transforms/IPO/FunctionImport.h"
1542418abaSMehdi Amini 
1642418abaSMehdi Amini #include "llvm/ADT/StringSet.h"
1742418abaSMehdi Amini #include "llvm/IR/AutoUpgrade.h"
1842418abaSMehdi Amini #include "llvm/IR/DiagnosticPrinter.h"
1942418abaSMehdi Amini #include "llvm/IR/IntrinsicInst.h"
2042418abaSMehdi Amini #include "llvm/IR/Module.h"
2142418abaSMehdi Amini #include "llvm/IRReader/IRReader.h"
2242418abaSMehdi Amini #include "llvm/Linker/Linker.h"
2342418abaSMehdi Amini #include "llvm/Object/FunctionIndexObjectFile.h"
2442418abaSMehdi Amini #include "llvm/Support/CommandLine.h"
2542418abaSMehdi Amini #include "llvm/Support/Debug.h"
2642418abaSMehdi Amini #include "llvm/Support/SourceMgr.h"
2742418abaSMehdi Amini using namespace llvm;
2842418abaSMehdi Amini 
2942418abaSMehdi Amini #define DEBUG_TYPE "function-import"
3042418abaSMehdi Amini 
3142418abaSMehdi Amini // Load lazily a module from \p FileName in \p Context.
3242418abaSMehdi Amini static std::unique_ptr<Module> loadFile(const std::string &FileName,
3342418abaSMehdi Amini                                         LLVMContext &Context) {
3442418abaSMehdi Amini   SMDiagnostic Err;
3542418abaSMehdi Amini   DEBUG(dbgs() << "Loading '" << FileName << "'\n");
3642418abaSMehdi Amini   std::unique_ptr<Module> Result = getLazyIRFileModule(FileName, Err, Context);
3742418abaSMehdi Amini   if (!Result) {
3842418abaSMehdi Amini     Err.print("function-import", errs());
3942418abaSMehdi Amini     return nullptr;
4042418abaSMehdi Amini   }
4142418abaSMehdi Amini 
4242418abaSMehdi Amini   Result->materializeMetadata();
4342418abaSMehdi Amini   UpgradeDebugInfo(*Result);
4442418abaSMehdi Amini 
4542418abaSMehdi Amini   return Result;
4642418abaSMehdi Amini }
4742418abaSMehdi Amini 
4842418abaSMehdi Amini // Get a Module for \p FileName from the cache, or load it lazily.
4942418abaSMehdi Amini Module &FunctionImporter::getOrLoadModule(StringRef FileName) {
5042418abaSMehdi Amini   auto &Module = ModuleMap[FileName];
5142418abaSMehdi Amini   if (!Module)
5242418abaSMehdi Amini     Module = loadFile(FileName, Context);
5342418abaSMehdi Amini   return *Module;
5442418abaSMehdi Amini }
5542418abaSMehdi Amini 
5642418abaSMehdi Amini // Automatically import functions in Module \p M based on the summaries index.
5742418abaSMehdi Amini //
5842418abaSMehdi Amini // The current implementation imports every called functions that exists in the
5942418abaSMehdi Amini // summaries index.
6042418abaSMehdi Amini bool FunctionImporter::importFunctions(Module &M) {
6142418abaSMehdi Amini   assert(&Context == &M.getContext());
6242418abaSMehdi Amini 
6342418abaSMehdi Amini   bool Changed = false;
6442418abaSMehdi Amini 
6542418abaSMehdi Amini   /// First step is collecting the called functions and the one defined in this
6642418abaSMehdi Amini   /// module.
6742418abaSMehdi Amini   StringSet<> CalledFunctions;
6842418abaSMehdi Amini   for (auto &F : M) {
6942418abaSMehdi Amini     if (F.isDeclaration() || F.hasFnAttribute(Attribute::OptimizeNone))
7042418abaSMehdi Amini       continue;
7142418abaSMehdi Amini     for (auto &BB : F) {
7242418abaSMehdi Amini       for (auto &I : BB) {
7342418abaSMehdi Amini         if (isa<CallInst>(I)) {
7442418abaSMehdi Amini           DEBUG(dbgs() << "Found a call: '" << I << "'\n");
7542418abaSMehdi Amini           auto CalledFunction = cast<CallInst>(I).getCalledFunction();
7642418abaSMehdi Amini           if (CalledFunction && CalledFunction->hasName() &&
7742418abaSMehdi Amini               CalledFunction->isDeclaration())
7842418abaSMehdi Amini             CalledFunctions.insert(CalledFunction->getName());
7942418abaSMehdi Amini         }
8042418abaSMehdi Amini       }
8142418abaSMehdi Amini     }
8242418abaSMehdi Amini   }
8342418abaSMehdi Amini 
8442418abaSMehdi Amini   /// Second step: for every call to an external function, try to import it.
8542418abaSMehdi Amini 
8642418abaSMehdi Amini   // Linker that will be used for importing function
8742418abaSMehdi Amini   Linker L(&M, DiagnosticHandler);
8842418abaSMehdi Amini 
8942418abaSMehdi Amini   /// Insert initial called function set in a worklist, so that we can add
9042418abaSMehdi Amini   /// transively called functions when importing.
9142418abaSMehdi Amini   SmallVector<StringRef, 64> Worklist;
9242418abaSMehdi Amini   for (auto &CalledFunction : CalledFunctions)
9342418abaSMehdi Amini     Worklist.push_back(CalledFunction.first());
9442418abaSMehdi Amini 
9542418abaSMehdi Amini   while (!Worklist.empty()) {
9642418abaSMehdi Amini     auto CalledFunctionName = Worklist.pop_back_val();
9742418abaSMehdi Amini     DEBUG(dbgs() << "Process import for " << CalledFunctionName << "\n");
9842418abaSMehdi Amini 
9942418abaSMehdi Amini     // Try to get a summary for this function call.
10042418abaSMehdi Amini     auto InfoList = Index.findFunctionInfoList(CalledFunctionName);
10142418abaSMehdi Amini     if (InfoList == Index.end()) {
10242418abaSMehdi Amini       DEBUG(dbgs() << "No summary for " << CalledFunctionName
10342418abaSMehdi Amini                    << " Ignoring.\n");
10442418abaSMehdi Amini       continue;
10542418abaSMehdi Amini     }
10642418abaSMehdi Amini     assert(!InfoList->second.empty() && "No summary, error at import?");
10742418abaSMehdi Amini 
10842418abaSMehdi Amini     // Comdat can have multiple entries, FIXME: what do we do with them?
10942418abaSMehdi Amini     auto &Info = InfoList->second[0];
11042418abaSMehdi Amini     assert(Info && "Nullptr in list, error importing summaries?\n");
11142418abaSMehdi Amini 
11242418abaSMehdi Amini     auto *Summary = Info->functionSummary();
11342418abaSMehdi Amini     if (!Summary) {
11442418abaSMehdi Amini       // FIXME: in case we are lazyloading summaries, we can do it now.
11542418abaSMehdi Amini       dbgs() << "Missing summary for  " << CalledFunctionName
11642418abaSMehdi Amini              << ", error at import?\n";
11742418abaSMehdi Amini       llvm_unreachable("Missing summary");
11842418abaSMehdi Amini     }
11942418abaSMehdi Amini 
12042418abaSMehdi Amini     //
12142418abaSMehdi Amini     // No profitability notion right now, just import all the time...
12242418abaSMehdi Amini     //
12342418abaSMehdi Amini 
12442418abaSMehdi Amini     // Get the module path from the summary.
12542418abaSMehdi Amini     auto FileName = Summary->modulePath();
12642418abaSMehdi Amini     DEBUG(dbgs() << "Importing " << CalledFunctionName << " from " << FileName
12742418abaSMehdi Amini                  << "\n");
12842418abaSMehdi Amini 
12942418abaSMehdi Amini     // Get the module for the import (potentially from the cache).
13042418abaSMehdi Amini     auto &Module = getOrLoadModule(FileName);
13142418abaSMehdi Amini 
13242418abaSMehdi Amini     // The function that we will import!
13342418abaSMehdi Amini     GlobalValue *SGV = Module.getNamedValue(CalledFunctionName);
134*130de7afSTeresa Johnson     StringRef ImportFunctionName = CalledFunctionName;
135*130de7afSTeresa Johnson     if (!SGV) {
136*130de7afSTeresa Johnson       // Might be local in source Module, promoted/renamed in dest Module M.
137*130de7afSTeresa Johnson       std::pair<StringRef, StringRef> Split =
138*130de7afSTeresa Johnson           CalledFunctionName.split(".llvm.");
139*130de7afSTeresa Johnson       SGV = Module.getNamedValue(Split.first);
140*130de7afSTeresa Johnson #ifndef NDEBUG
141*130de7afSTeresa Johnson       // Assert that Split.second is module id
142*130de7afSTeresa Johnson       uint64_t ModuleId;
143*130de7afSTeresa Johnson       assert(!Split.second.getAsInteger(10, ModuleId));
144*130de7afSTeresa Johnson       assert(ModuleId == Index.getModuleId(FileName));
145*130de7afSTeresa Johnson #endif
146*130de7afSTeresa Johnson     }
14742418abaSMehdi Amini     Function *F = dyn_cast<Function>(SGV);
14842418abaSMehdi Amini     if (!F && isa<GlobalAlias>(SGV)) {
14942418abaSMehdi Amini       auto *SGA = dyn_cast<GlobalAlias>(SGV);
15042418abaSMehdi Amini       F = dyn_cast<Function>(SGA->getBaseObject());
151*130de7afSTeresa Johnson       ImportFunctionName = F->getName();
15242418abaSMehdi Amini     }
15342418abaSMehdi Amini     if (!F) {
15442418abaSMehdi Amini       errs() << "Can't load function '" << CalledFunctionName << "' in Module '"
15542418abaSMehdi Amini              << FileName << "', error in the summary?\n";
15642418abaSMehdi Amini       llvm_unreachable("Can't load function in Module");
15742418abaSMehdi Amini     }
15842418abaSMehdi Amini 
15917626654STeresa Johnson     // We cannot import weak_any functions/aliases without possibly affecting
16017626654STeresa Johnson     // the order they are seen and selected by the linker, changing program
16142418abaSMehdi Amini     // semantics.
16217626654STeresa Johnson     if (SGV->hasWeakAnyLinkage()) {
16317626654STeresa Johnson       DEBUG(dbgs() << "Ignoring import request for weak-any "
16417626654STeresa Johnson                    << (isa<Function>(SGV) ? "function " : "alias ")
16542418abaSMehdi Amini                    << CalledFunctionName << " from " << FileName << "\n");
16642418abaSMehdi Amini       continue;
16742418abaSMehdi Amini     }
16842418abaSMehdi Amini 
16942418abaSMehdi Amini     // Link in the specified function.
17042418abaSMehdi Amini     if (L.linkInModule(&Module, Linker::Flags::None, &Index, F))
17142418abaSMehdi Amini       report_fatal_error("Function Import: link error");
17242418abaSMehdi Amini 
173*130de7afSTeresa Johnson     // Process the newly imported function and add callees to the worklist.
174*130de7afSTeresa Johnson     GlobalValue *NewGV = M.getNamedValue(ImportFunctionName);
175*130de7afSTeresa Johnson     assert(NewGV);
176*130de7afSTeresa Johnson     Function *NewF = dyn_cast<Function>(NewGV);
177*130de7afSTeresa Johnson     assert(NewF);
178*130de7afSTeresa Johnson 
179*130de7afSTeresa Johnson     for (auto &BB : *NewF) {
180*130de7afSTeresa Johnson       for (auto &I : BB) {
181*130de7afSTeresa Johnson         if (isa<CallInst>(I)) {
182*130de7afSTeresa Johnson           DEBUG(dbgs() << "Found a call: '" << I << "'\n");
183*130de7afSTeresa Johnson           auto CalledFunction = cast<CallInst>(I).getCalledFunction();
184*130de7afSTeresa Johnson           // Insert any new external calls that have not already been
185*130de7afSTeresa Johnson           // added to set/worklist.
186*130de7afSTeresa Johnson           if (CalledFunction && CalledFunction->hasName() &&
187*130de7afSTeresa Johnson               CalledFunction->isDeclaration() &&
188*130de7afSTeresa Johnson               !CalledFunctions.count(CalledFunction->getName())) {
189*130de7afSTeresa Johnson             CalledFunctions.insert(CalledFunction->getName());
190*130de7afSTeresa Johnson             Worklist.push_back(CalledFunction->getName());
191*130de7afSTeresa Johnson           }
192*130de7afSTeresa Johnson         }
193*130de7afSTeresa Johnson       }
194*130de7afSTeresa Johnson     }
19542418abaSMehdi Amini 
19642418abaSMehdi Amini     Changed = true;
19742418abaSMehdi Amini   }
19842418abaSMehdi Amini   return Changed;
19942418abaSMehdi Amini }
20042418abaSMehdi Amini 
20142418abaSMehdi Amini /// Summary file to use for function importing when using -function-import from
20242418abaSMehdi Amini /// the command line.
20342418abaSMehdi Amini static cl::opt<std::string>
20442418abaSMehdi Amini     SummaryFile("summary-file",
20542418abaSMehdi Amini                 cl::desc("The summary file to use for function importing."));
20642418abaSMehdi Amini 
20742418abaSMehdi Amini static void diagnosticHandler(const DiagnosticInfo &DI) {
20842418abaSMehdi Amini   raw_ostream &OS = errs();
20942418abaSMehdi Amini   DiagnosticPrinterRawOStream DP(OS);
21042418abaSMehdi Amini   DI.print(DP);
21142418abaSMehdi Amini   OS << '\n';
21242418abaSMehdi Amini }
21342418abaSMehdi Amini 
21442418abaSMehdi Amini /// Parse the function index out of an IR file and return the function
21542418abaSMehdi Amini /// index object if found, or nullptr if not.
21642418abaSMehdi Amini static std::unique_ptr<FunctionInfoIndex>
21742418abaSMehdi Amini getFunctionIndexForFile(StringRef Path, std::string &Error,
21842418abaSMehdi Amini                         DiagnosticHandlerFunction DiagnosticHandler) {
21942418abaSMehdi Amini   std::unique_ptr<MemoryBuffer> Buffer;
22042418abaSMehdi Amini   ErrorOr<std::unique_ptr<MemoryBuffer>> BufferOrErr =
22142418abaSMehdi Amini       MemoryBuffer::getFile(Path);
22242418abaSMehdi Amini   if (std::error_code EC = BufferOrErr.getError()) {
22342418abaSMehdi Amini     Error = EC.message();
22442418abaSMehdi Amini     return nullptr;
22542418abaSMehdi Amini   }
22642418abaSMehdi Amini   Buffer = std::move(BufferOrErr.get());
22742418abaSMehdi Amini   ErrorOr<std::unique_ptr<object::FunctionIndexObjectFile>> ObjOrErr =
22842418abaSMehdi Amini       object::FunctionIndexObjectFile::create(Buffer->getMemBufferRef(),
22942418abaSMehdi Amini                                               DiagnosticHandler);
23042418abaSMehdi Amini   if (std::error_code EC = ObjOrErr.getError()) {
23142418abaSMehdi Amini     Error = EC.message();
23242418abaSMehdi Amini     return nullptr;
23342418abaSMehdi Amini   }
23442418abaSMehdi Amini   return (*ObjOrErr)->takeIndex();
23542418abaSMehdi Amini }
23642418abaSMehdi Amini 
23742418abaSMehdi Amini /// Pass that performs cross-module function import provided a summary file.
23842418abaSMehdi Amini class FunctionImportPass : public ModulePass {
23942418abaSMehdi Amini 
24042418abaSMehdi Amini public:
24142418abaSMehdi Amini   /// Pass identification, replacement for typeid
24242418abaSMehdi Amini   static char ID;
24342418abaSMehdi Amini 
24442418abaSMehdi Amini   explicit FunctionImportPass() : ModulePass(ID) {}
24542418abaSMehdi Amini 
24642418abaSMehdi Amini   bool runOnModule(Module &M) override {
24742418abaSMehdi Amini     if (SummaryFile.empty()) {
24842418abaSMehdi Amini       report_fatal_error("error: -function-import requires -summary-file\n");
24942418abaSMehdi Amini     }
25042418abaSMehdi Amini     std::string Error;
25142418abaSMehdi Amini     std::unique_ptr<FunctionInfoIndex> Index =
25242418abaSMehdi Amini         getFunctionIndexForFile(SummaryFile, Error, diagnosticHandler);
25342418abaSMehdi Amini     if (!Index) {
25442418abaSMehdi Amini       errs() << "Error loading file '" << SummaryFile << "': " << Error << "\n";
25542418abaSMehdi Amini       return false;
25642418abaSMehdi Amini     }
25742418abaSMehdi Amini 
25842418abaSMehdi Amini     // Perform the import now.
25942418abaSMehdi Amini     FunctionImporter Importer(M.getContext(), *Index, diagnosticHandler);
26042418abaSMehdi Amini     return Importer.importFunctions(M);
26142418abaSMehdi Amini 
26242418abaSMehdi Amini     return false;
26342418abaSMehdi Amini   }
26442418abaSMehdi Amini };
26542418abaSMehdi Amini 
26642418abaSMehdi Amini char FunctionImportPass::ID = 0;
26742418abaSMehdi Amini INITIALIZE_PASS_BEGIN(FunctionImportPass, "function-import",
26842418abaSMehdi Amini                       "Summary Based Function Import", false, false)
26942418abaSMehdi Amini INITIALIZE_PASS_END(FunctionImportPass, "function-import",
27042418abaSMehdi Amini                     "Summary Based Function Import", false, false)
27142418abaSMehdi Amini 
27242418abaSMehdi Amini namespace llvm {
27342418abaSMehdi Amini Pass *createFunctionImportPass() { return new FunctionImportPass(); }
27442418abaSMehdi Amini }
275