1 //===-- WebAssemblyAddMissingPrototypes.cpp - Fix prototypeless functions -===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
10 /// \file
/// Add prototypes to prototype-less functions.
12 ///
/// WebAssembly has strict function prototype checking, so function
/// declarations need to match their call sites.  Clang treats prototype-less
/// functions as varargs (foo(...)), which happens to work on existing
/// platforms but doesn't under WebAssembly.  This pass finds all the call
/// sites of each prototype-less function, ensures they agree, and then sets
/// the signature on the function declaration accordingly.
19 ///
20 //===----------------------------------------------------------------------===//
21 
22 #include "WebAssembly.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/IR/Operator.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Transforms/Utils/Local.h"
30 #include "llvm/Transforms/Utils/ModuleUtils.h"
31 using namespace llvm;
32 
33 #define DEBUG_TYPE "wasm-add-missing-prototypes"
34 
35 namespace {
36 class WebAssemblyAddMissingPrototypes final : public ModulePass {
getPassName() const37   StringRef getPassName() const override {
38     return "Add prototypes to prototypes-less functions";
39   }
40 
getAnalysisUsage(AnalysisUsage & AU) const41   void getAnalysisUsage(AnalysisUsage &AU) const override {
42     AU.setPreservesCFG();
43     ModulePass::getAnalysisUsage(AU);
44   }
45 
46   bool runOnModule(Module &M) override;
47 
48 public:
49   static char ID;
WebAssemblyAddMissingPrototypes()50   WebAssemblyAddMissingPrototypes() : ModulePass(ID) {}
51 };
52 } // End anonymous namespace
53 
54 char WebAssemblyAddMissingPrototypes::ID = 0;
55 INITIALIZE_PASS(WebAssemblyAddMissingPrototypes, DEBUG_TYPE,
56                 "Add prototypes to prototypes-less functions", false, false)
57 
createWebAssemblyAddMissingPrototypes()58 ModulePass *llvm::createWebAssemblyAddMissingPrototypes() {
59   return new WebAssemblyAddMissingPrototypes();
60 }
61 
runOnModule(Module & M)62 bool WebAssemblyAddMissingPrototypes::runOnModule(Module &M) {
63   LLVM_DEBUG(dbgs() << "********** Add Missing Prototypes **********\n");
64 
65   std::vector<std::pair<Function *, Function *>> Replacements;
66 
67   // Find all the prototype-less function declarations
68   for (Function &F : M) {
69     if (!F.isDeclaration() || !F.hasFnAttribute("no-prototype"))
70       continue;
71 
72     LLVM_DEBUG(dbgs() << "Found no-prototype function: " << F.getName()
73                       << "\n");
74 
75     // When clang emits prototype-less C functions it uses (...), i.e. varargs
76     // function that take no arguments (have no sentinel).  When we see a
77     // no-prototype attribute we expect the function have these properties.
78     if (!F.isVarArg())
79       report_fatal_error(
80           "Functions with 'no-prototype' attribute must take varargs: " +
81           F.getName());
82     if (F.getFunctionType()->getNumParams() != 0)
83       report_fatal_error(
84           "Functions with 'no-prototype' attribute should not have params: " +
85           F.getName());
86 
87     // Create a function prototype based on the first call site (first bitcast)
88     // that we find.
89     FunctionType *NewType = nullptr;
90     Function *NewF = nullptr;
91     for (Use &U : F.uses()) {
92       LLVM_DEBUG(dbgs() << "prototype-less use: " << F.getName() << "\n");
93       if (auto *BC = dyn_cast<BitCastOperator>(U.getUser())) {
94         if (auto *DestType = dyn_cast<FunctionType>(
95                 BC->getDestTy()->getPointerElementType())) {
96           if (!NewType) {
97             // Create a new function with the correct type
98             NewType = DestType;
99             NewF = Function::Create(NewType, F.getLinkage(), F.getName());
100             NewF->setAttributes(F.getAttributes());
101             NewF->removeFnAttr("no-prototype");
102           } else {
103             if (NewType != DestType) {
104               report_fatal_error("Prototypeless function used with "
105                                  "conflicting signatures: " +
106                                  F.getName());
107             }
108           }
109         }
110       }
111     }
112 
113     if (!NewType) {
114       LLVM_DEBUG(
115           dbgs() << "could not derive a function prototype from usage: " +
116                         F.getName() + "\n");
117       continue;
118     }
119 
120     SmallVector<Instruction *, 4> DeadInsts;
121 
122     for (Use &US : F.uses()) {
123       User *U = US.getUser();
124       if (auto *BC = dyn_cast<BitCastOperator>(U)) {
125         if (auto *Inst = dyn_cast<BitCastInst>(U)) {
126           // Replace with a new bitcast
127           IRBuilder<> Builder(Inst);
128           Value *NewCast = Builder.CreatePointerCast(NewF, BC->getDestTy());
129           Inst->replaceAllUsesWith(NewCast);
130           DeadInsts.push_back(Inst);
131         } else if (auto *Const = dyn_cast<ConstantExpr>(U)) {
132           Constant *NewConst =
133               ConstantExpr::getPointerCast(NewF, BC->getDestTy());
134           Const->replaceAllUsesWith(NewConst);
135         } else {
136           dbgs() << *U->getType() << "\n";
137 #ifndef NDEBUG
138           U->dump();
139 #endif
140           report_fatal_error("unexpected use of prototypeless function: " +
141                              F.getName() + "\n");
142         }
143       }
144     }
145 
146     for (auto I : DeadInsts)
147       I->eraseFromParent();
148     Replacements.emplace_back(&F, NewF);
149   }
150 
151 
152   // Finally replace the old function declarations with the new ones
153   for (auto &Pair : Replacements) {
154     Function *Old = Pair.first;
155     Function *New = Pair.second;
156     Old->eraseFromParent();
157     M.getFunctionList().push_back(New);
158   }
159 
160   return !Replacements.empty();
161 }
162