1 //===- Standard pass instrumentations handling ----------------*- C++ -*--===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 /// \file
10 ///
11 /// This file defines IR-printing pass instrumentation callbacks as well as
12 /// StandardInstrumentations class that manages standard pass instrumentations.
13 ///
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Passes/StandardInstrumentations.h"
17 #include "llvm/Analysis/CallGraphSCCPass.h"
18 #include "llvm/Analysis/LazyCallGraph.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/IRPrintingPasses.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/PassInstrumentation.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/raw_ostream.h"
27 
28 using namespace llvm;
29 
30 namespace {
31 namespace PrintIR {
32 
33 //===----------------------------------------------------------------------===//
34 // IR-printing instrumentation
35 //===----------------------------------------------------------------------===//
36 
37 /// Generic IR-printing helper that unpacks a pointer to IRUnit wrapped into
38 /// llvm::Any and does actual print job.
39 void unwrapAndPrint(StringRef Banner, Any IR) {
40   SmallString<40> Extra{"\n"};
41   const Module *M = nullptr;
42   if (any_isa<const Module *>(IR)) {
43     M = any_cast<const Module *>(IR);
44   } else if (any_isa<const Function *>(IR)) {
45     const Function *F = any_cast<const Function *>(IR);
46     if (!llvm::isFunctionInPrintList(F->getName()))
47       return;
48     if (!llvm::forcePrintModuleIR()) {
49       dbgs() << Banner << Extra << static_cast<const Value &>(*F);
50       return;
51     }
52     M = F->getParent();
53     Extra = formatv(" (function: {0})\n", F->getName());
54   } else if (any_isa<const LazyCallGraph::SCC *>(IR)) {
55     const LazyCallGraph::SCC *C = any_cast<const LazyCallGraph::SCC *>(IR);
56     assert(C);
57     if (!llvm::forcePrintModuleIR()) {
58       Extra = formatv(" (scc: {0})\n", C->getName());
59       bool BannerPrinted = false;
60       for (const LazyCallGraph::Node &N : *C) {
61         const Function &F = N.getFunction();
62         if (!F.isDeclaration() && isFunctionInPrintList(F.getName())) {
63           if (!BannerPrinted) {
64             dbgs() << Banner << Extra;
65             BannerPrinted = true;
66           }
67           F.print(dbgs());
68         }
69       }
70       return;
71     }
72     for (const LazyCallGraph::Node &N : *C) {
73       const Function &F = N.getFunction();
74       if (!F.isDeclaration() && isFunctionInPrintList(F.getName())) {
75         M = F.getParent();
76         break;
77       }
78     }
79     if (!M)
80       return;
81     Extra = formatv(" (for scc: {0})\n", C->getName());
82   } else if (any_isa<const Loop *>(IR)) {
83     const Loop *L = any_cast<const Loop *>(IR);
84     const Function *F = L->getHeader()->getParent();
85     if (!isFunctionInPrintList(F->getName()))
86       return;
87     if (!llvm::forcePrintModuleIR()) {
88       llvm::printLoop(const_cast<Loop &>(*L), dbgs(), Banner);
89       return;
90     }
91     M = F->getParent();
92     {
93       std::string LoopName;
94       raw_string_ostream ss(LoopName);
95       L->getHeader()->printAsOperand(ss, false);
96       Extra = formatv(" (loop: {0})\n", ss.str());
97     }
98   }
99   if (M) {
100     dbgs() << Banner << Extra;
101     M->print(dbgs(), nullptr, false);
102   } else {
103     llvm_unreachable("Unknown wrapped IR type");
104   }
105 }
106 
107 bool printBeforePass(StringRef PassID, Any IR) {
108   if (!llvm::shouldPrintBeforePass(PassID))
109     return true;
110 
111   if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<"))
112     return true;
113 
114   SmallString<20> Banner = formatv("*** IR Dump Before {0} ***", PassID);
115   unwrapAndPrint(Banner, IR);
116   return true;
117 }
118 
119 void printAfterPass(StringRef PassID, Any IR) {
120   if (!llvm::shouldPrintAfterPass(PassID))
121     return;
122 
123   if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<"))
124     return;
125 
126   SmallString<20> Banner = formatv("*** IR Dump After {0} ***", PassID);
127   unwrapAndPrint(Banner, IR);
128   return;
129 }
130 } // namespace PrintIR
131 } // namespace
132 
133 void StandardInstrumentations::registerCallbacks(
134     PassInstrumentationCallbacks &PIC) {
135   if (llvm::shouldPrintBeforePass())
136     PIC.registerBeforePassCallback(PrintIR::printBeforePass);
137   if (llvm::shouldPrintAfterPass())
138     PIC.registerAfterPassCallback(PrintIR::printAfterPass);
139   TimePasses.registerCallbacks(PIC);
140 }
141