1 //===- Standard pass instrumentations handling ----------------*- C++ -*--===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 /// \file
10 ///
11 /// This file defines IR-printing pass instrumentation callbacks as well as
12 /// StandardInstrumentations class that manages standard pass instrumentations.
13 ///
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Passes/StandardInstrumentations.h"
17 #include "llvm/ADT/Optional.h"
18 #include "llvm/Analysis/CallGraphSCCPass.h"
19 #include "llvm/Analysis/LazyCallGraph.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/IRPrintingPasses.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/IR/PassInstrumentation.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 using namespace llvm;
30 
31 namespace {
32 
33 /// Extracting Module out of \p IR unit. Also fills a textual description
34 /// of \p IR for use in header when printing.
unwrapModule(Any IR)35 Optional<std::pair<const Module *, std::string>> unwrapModule(Any IR) {
36   if (any_isa<const Module *>(IR))
37     return std::make_pair(any_cast<const Module *>(IR), std::string());
38 
39   if (any_isa<const Function *>(IR)) {
40     const Function *F = any_cast<const Function *>(IR);
41     if (!llvm::isFunctionInPrintList(F->getName()))
42       return None;
43     const Module *M = F->getParent();
44     return std::make_pair(M, formatv(" (function: {0})", F->getName()).str());
45   }
46 
47   if (any_isa<const LazyCallGraph::SCC *>(IR)) {
48     const LazyCallGraph::SCC *C = any_cast<const LazyCallGraph::SCC *>(IR);
49     for (const LazyCallGraph::Node &N : *C) {
50       const Function &F = N.getFunction();
51       if (!F.isDeclaration() && isFunctionInPrintList(F.getName())) {
52         const Module *M = F.getParent();
53         return std::make_pair(M, formatv(" (scc: {0})", C->getName()).str());
54       }
55     }
56     return None;
57   }
58 
59   if (any_isa<const Loop *>(IR)) {
60     const Loop *L = any_cast<const Loop *>(IR);
61     const Function *F = L->getHeader()->getParent();
62     if (!isFunctionInPrintList(F->getName()))
63       return None;
64     const Module *M = F->getParent();
65     std::string LoopName;
66     raw_string_ostream ss(LoopName);
67     L->getHeader()->printAsOperand(ss, false);
68     return std::make_pair(M, formatv(" (loop: {0})", ss.str()).str());
69   }
70 
71   llvm_unreachable("Unknown IR unit");
72 }
73 
printIR(const Module * M,StringRef Banner,StringRef Extra=StringRef ())74 void printIR(const Module *M, StringRef Banner, StringRef Extra = StringRef()) {
75   dbgs() << Banner << Extra << "\n";
76   M->print(dbgs(), nullptr, false);
77 }
printIR(const Function * F,StringRef Banner,StringRef Extra=StringRef ())78 void printIR(const Function *F, StringRef Banner,
79              StringRef Extra = StringRef()) {
80   if (!llvm::isFunctionInPrintList(F->getName()))
81     return;
82   dbgs() << Banner << Extra << "\n" << static_cast<const Value &>(*F);
83 }
printIR(const LazyCallGraph::SCC * C,StringRef Banner,StringRef Extra=StringRef ())84 void printIR(const LazyCallGraph::SCC *C, StringRef Banner,
85              StringRef Extra = StringRef()) {
86   bool BannerPrinted = false;
87   for (const LazyCallGraph::Node &N : *C) {
88     const Function &F = N.getFunction();
89     if (!F.isDeclaration() && llvm::isFunctionInPrintList(F.getName())) {
90       if (!BannerPrinted) {
91         dbgs() << Banner << Extra << "\n";
92         BannerPrinted = true;
93       }
94       F.print(dbgs());
95     }
96   }
97 }
printIR(const Loop * L,StringRef Banner)98 void printIR(const Loop *L, StringRef Banner) {
99   const Function *F = L->getHeader()->getParent();
100   if (!llvm::isFunctionInPrintList(F->getName()))
101     return;
102   llvm::printLoop(const_cast<Loop &>(*L), dbgs(), Banner);
103 }
104 
105 /// Generic IR-printing helper that unpacks a pointer to IRUnit wrapped into
106 /// llvm::Any and does actual print job.
unwrapAndPrint(Any IR,StringRef Banner,bool ForceModule=false)107 void unwrapAndPrint(Any IR, StringRef Banner, bool ForceModule = false) {
108   if (ForceModule) {
109     if (auto UnwrappedModule = unwrapModule(IR))
110       printIR(UnwrappedModule->first, Banner, UnwrappedModule->second);
111     return;
112   }
113 
114   if (any_isa<const Module *>(IR)) {
115     const Module *M = any_cast<const Module *>(IR);
116     assert(M && "module should be valid for printing");
117     printIR(M, Banner);
118     return;
119   }
120 
121   if (any_isa<const Function *>(IR)) {
122     const Function *F = any_cast<const Function *>(IR);
123     assert(F && "function should be valid for printing");
124     printIR(F, Banner);
125     return;
126   }
127 
128   if (any_isa<const LazyCallGraph::SCC *>(IR)) {
129     const LazyCallGraph::SCC *C = any_cast<const LazyCallGraph::SCC *>(IR);
130     assert(C && "scc should be valid for printing");
131     std::string Extra = formatv(" (scc: {0})", C->getName());
132     printIR(C, Banner, Extra);
133     return;
134   }
135 
136   if (any_isa<const Loop *>(IR)) {
137     const Loop *L = any_cast<const Loop *>(IR);
138     assert(L && "Loop should be valid for printing");
139     printIR(L, Banner);
140     return;
141   }
142   llvm_unreachable("Unknown wrapped IR type");
143 }
144 
145 } // namespace
146 
/// Destructor. In assertion-enabled builds, checks that every module
/// description pushed by pushModuleDesc() has been popped again, i.e. that
/// before-pass and after-pass callbacks were balanced.
PrintIRInstrumentation::~PrintIRInstrumentation() {
  assert(ModuleDescStack.empty() && "ModuleDescStack is not empty at exit");
}
150 
pushModuleDesc(StringRef PassID,Any IR)151 void PrintIRInstrumentation::pushModuleDesc(StringRef PassID, Any IR) {
152   assert(StoreModuleDesc);
153   const Module *M = nullptr;
154   std::string Extra;
155   if (auto UnwrappedModule = unwrapModule(IR))
156     std::tie(M, Extra) = UnwrappedModule.getValue();
157   ModuleDescStack.emplace_back(M, Extra, PassID);
158 }
159 
160 PrintIRInstrumentation::PrintModuleDesc
popModuleDesc(StringRef PassID)161 PrintIRInstrumentation::popModuleDesc(StringRef PassID) {
162   assert(!ModuleDescStack.empty() && "empty ModuleDescStack");
163   PrintModuleDesc ModuleDesc = ModuleDescStack.pop_back_val();
164   assert(std::get<2>(ModuleDesc).equals(PassID) && "malformed ModuleDescStack");
165   return ModuleDesc;
166 }
167 
printBeforePass(StringRef PassID,Any IR)168 bool PrintIRInstrumentation::printBeforePass(StringRef PassID, Any IR) {
169   if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<"))
170     return true;
171 
172   // Saving Module for AfterPassInvalidated operations.
173   // Note: here we rely on a fact that we do not change modules while
174   // traversing the pipeline, so the latest captured module is good
175   // for all print operations that has not happen yet.
176   if (StoreModuleDesc && llvm::shouldPrintAfterPass(PassID))
177     pushModuleDesc(PassID, IR);
178 
179   if (!llvm::shouldPrintBeforePass(PassID))
180     return true;
181 
182   SmallString<20> Banner = formatv("*** IR Dump Before {0} ***", PassID);
183   unwrapAndPrint(IR, Banner, llvm::forcePrintModuleIR());
184   return true;
185 }
186 
printAfterPass(StringRef PassID,Any IR)187 void PrintIRInstrumentation::printAfterPass(StringRef PassID, Any IR) {
188   if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<"))
189     return;
190 
191   if (!llvm::shouldPrintAfterPass(PassID))
192     return;
193 
194   if (StoreModuleDesc)
195     popModuleDesc(PassID);
196 
197   SmallString<20> Banner = formatv("*** IR Dump After {0} ***", PassID);
198   unwrapAndPrint(IR, Banner, llvm::forcePrintModuleIR());
199 }
200 
printAfterPassInvalidated(StringRef PassID)201 void PrintIRInstrumentation::printAfterPassInvalidated(StringRef PassID) {
202   if (!StoreModuleDesc || !llvm::shouldPrintAfterPass(PassID))
203     return;
204 
205   if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<"))
206     return;
207 
208   const Module *M;
209   std::string Extra;
210   StringRef StoredPassID;
211   std::tie(M, Extra, StoredPassID) = popModuleDesc(PassID);
212   // Additional filtering (e.g. -filter-print-func) can lead to module
213   // printing being skipped.
214   if (!M)
215     return;
216 
217   SmallString<20> Banner =
218       formatv("*** IR Dump After {0} *** invalidated: ", PassID);
219   printIR(M, Banner, Extra);
220 }
221 
registerCallbacks(PassInstrumentationCallbacks & PIC)222 void PrintIRInstrumentation::registerCallbacks(
223     PassInstrumentationCallbacks &PIC) {
224   // BeforePass callback is not just for printing, it also saves a Module
225   // for later use in AfterPassInvalidated.
226   StoreModuleDesc = llvm::forcePrintModuleIR() && llvm::shouldPrintAfterPass();
227   if (llvm::shouldPrintBeforePass() || StoreModuleDesc)
228     PIC.registerBeforePassCallback(
229         [this](StringRef P, Any IR) { return this->printBeforePass(P, IR); });
230 
231   if (llvm::shouldPrintAfterPass()) {
232     PIC.registerAfterPassCallback(
233         [this](StringRef P, Any IR) { this->printAfterPass(P, IR); });
234     PIC.registerAfterPassInvalidatedCallback(
235         [this](StringRef P) { this->printAfterPassInvalidated(P); });
236   }
237 }
238 
/// Registers every standard instrumentation with \p PIC: the IR-printing
/// callbacks first, then the pass-timing callbacks.
void StandardInstrumentations::registerCallbacks(
    PassInstrumentationCallbacks &PIC) {
  PrintIR.registerCallbacks(PIC);
  TimePasses.registerCallbacks(PIC);
}
244