124e2fe98SDimitry Andric //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
224e2fe98SDimitry Andric //
324e2fe98SDimitry Andric //                     The LLVM Compiler Infrastructure
424e2fe98SDimitry Andric //
524e2fe98SDimitry Andric // This file is distributed under the University of Illinois Open Source
624e2fe98SDimitry Andric // License. See LICENSE.TXT for details.
724e2fe98SDimitry Andric //
824e2fe98SDimitry Andric //===----------------------------------------------------------------------===//
924e2fe98SDimitry Andric ///
1024e2fe98SDimitry Andric /// \file
1124e2fe98SDimitry Andric /// \brief Fix bitcasted functions.
1224e2fe98SDimitry Andric ///
1324e2fe98SDimitry Andric /// WebAssembly requires caller and callee signatures to match, however in LLVM,
1424e2fe98SDimitry Andric /// some amount of slop is vaguely permitted. Detect mismatch by looking for
1524e2fe98SDimitry Andric /// bitcasts of functions and rewrite them to use wrapper functions instead.
1624e2fe98SDimitry Andric ///
/// This doesn't catch all cases, such as when a function's address is taken in
/// one place and cast in another, but it works for many common cases.
1924e2fe98SDimitry Andric ///
2024e2fe98SDimitry Andric /// Note that LLVM already optimizes away function bitcasts in common cases by
2124e2fe98SDimitry Andric /// dropping arguments as needed, so this pass only ends up getting used in less
2224e2fe98SDimitry Andric /// common cases.
2324e2fe98SDimitry Andric ///
2424e2fe98SDimitry Andric //===----------------------------------------------------------------------===//
2524e2fe98SDimitry Andric 
#include "WebAssembly.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
3424e2fe98SDimitry Andric using namespace llvm;
3524e2fe98SDimitry Andric 
3624e2fe98SDimitry Andric #define DEBUG_TYPE "wasm-fix-function-bitcasts"
3724e2fe98SDimitry Andric 
3824e2fe98SDimitry Andric namespace {
3924e2fe98SDimitry Andric class FixFunctionBitcasts final : public ModulePass {
4024e2fe98SDimitry Andric   StringRef getPassName() const override {
4124e2fe98SDimitry Andric     return "WebAssembly Fix Function Bitcasts";
4224e2fe98SDimitry Andric   }
4324e2fe98SDimitry Andric 
4424e2fe98SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
4524e2fe98SDimitry Andric     AU.setPreservesCFG();
4624e2fe98SDimitry Andric     ModulePass::getAnalysisUsage(AU);
4724e2fe98SDimitry Andric   }
4824e2fe98SDimitry Andric 
4924e2fe98SDimitry Andric   bool runOnModule(Module &M) override;
5024e2fe98SDimitry Andric 
5124e2fe98SDimitry Andric public:
5224e2fe98SDimitry Andric   static char ID;
5324e2fe98SDimitry Andric   FixFunctionBitcasts() : ModulePass(ID) {}
5424e2fe98SDimitry Andric };
5524e2fe98SDimitry Andric } // End anonymous namespace
5624e2fe98SDimitry Andric 
5724e2fe98SDimitry Andric char FixFunctionBitcasts::ID = 0;
5824e2fe98SDimitry Andric ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
5924e2fe98SDimitry Andric   return new FixFunctionBitcasts();
6024e2fe98SDimitry Andric }
6124e2fe98SDimitry Andric 
6224e2fe98SDimitry Andric // Recursively descend the def-use lists from V to find non-bitcast users of
6324e2fe98SDimitry Andric // bitcasts of V.
6424e2fe98SDimitry Andric static void FindUses(Value *V, Function &F,
6524e2fe98SDimitry Andric                      SmallVectorImpl<std::pair<Use *, Function *>> &Uses) {
6624e2fe98SDimitry Andric   for (Use &U : V->uses()) {
6724e2fe98SDimitry Andric     if (BitCastOperator *BC = dyn_cast<BitCastOperator>(U.getUser()))
6824e2fe98SDimitry Andric       FindUses(BC, F, Uses);
6924e2fe98SDimitry Andric     else if (U.get()->getType() != F.getType())
7024e2fe98SDimitry Andric       Uses.push_back(std::make_pair(&U, &F));
7124e2fe98SDimitry Andric   }
7224e2fe98SDimitry Andric }
7324e2fe98SDimitry Andric 
7424e2fe98SDimitry Andric // Create a wrapper function with type Ty that calls F (which may have a
7524e2fe98SDimitry Andric // different type). Attempt to support common bitcasted function idioms:
7624e2fe98SDimitry Andric //  - Call with more arguments than needed: arguments are dropped
7724e2fe98SDimitry Andric //  - Call with fewer arguments than needed: arguments are filled in with undef
7824e2fe98SDimitry Andric //  - Return value is not needed: drop it
7924e2fe98SDimitry Andric //  - Return value needed but not present: supply an undef
8024e2fe98SDimitry Andric //
8124e2fe98SDimitry Andric // For now, return nullptr without creating a wrapper if the wrapper cannot
8224e2fe98SDimitry Andric // be generated due to incompatible types.
8324e2fe98SDimitry Andric static Function *CreateWrapper(Function *F, FunctionType *Ty) {
8424e2fe98SDimitry Andric   Module *M = F->getParent();
8524e2fe98SDimitry Andric 
8624e2fe98SDimitry Andric   Function *Wrapper =
8724e2fe98SDimitry Andric       Function::Create(Ty, Function::PrivateLinkage, "bitcast", M);
8824e2fe98SDimitry Andric   BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
8924e2fe98SDimitry Andric 
9024e2fe98SDimitry Andric   // Determine what arguments to pass.
9124e2fe98SDimitry Andric   SmallVector<Value *, 4> Args;
9224e2fe98SDimitry Andric   Function::arg_iterator AI = Wrapper->arg_begin();
9324e2fe98SDimitry Andric   FunctionType::param_iterator PI = F->getFunctionType()->param_begin();
9424e2fe98SDimitry Andric   FunctionType::param_iterator PE = F->getFunctionType()->param_end();
9524e2fe98SDimitry Andric   for (; AI != Wrapper->arg_end() && PI != PE; ++AI, ++PI) {
9624e2fe98SDimitry Andric     if (AI->getType() != *PI) {
9724e2fe98SDimitry Andric       Wrapper->eraseFromParent();
9824e2fe98SDimitry Andric       return nullptr;
9924e2fe98SDimitry Andric     }
10024e2fe98SDimitry Andric     Args.push_back(&*AI);
10124e2fe98SDimitry Andric   }
10224e2fe98SDimitry Andric   for (; PI != PE; ++PI)
10324e2fe98SDimitry Andric     Args.push_back(UndefValue::get(*PI));
10424e2fe98SDimitry Andric 
10524e2fe98SDimitry Andric   CallInst *Call = CallInst::Create(F, Args, "", BB);
10624e2fe98SDimitry Andric 
10724e2fe98SDimitry Andric   // Determine what value to return.
10824e2fe98SDimitry Andric   if (Ty->getReturnType()->isVoidTy())
10924e2fe98SDimitry Andric     ReturnInst::Create(M->getContext(), BB);
11024e2fe98SDimitry Andric   else if (F->getFunctionType()->getReturnType()->isVoidTy())
11124e2fe98SDimitry Andric     ReturnInst::Create(M->getContext(), UndefValue::get(Ty->getReturnType()),
11224e2fe98SDimitry Andric                        BB);
11324e2fe98SDimitry Andric   else if (F->getFunctionType()->getReturnType() == Ty->getReturnType())
11424e2fe98SDimitry Andric     ReturnInst::Create(M->getContext(), Call, BB);
11524e2fe98SDimitry Andric   else {
11624e2fe98SDimitry Andric     Wrapper->eraseFromParent();
11724e2fe98SDimitry Andric     return nullptr;
11824e2fe98SDimitry Andric   }
11924e2fe98SDimitry Andric 
12024e2fe98SDimitry Andric   return Wrapper;
12124e2fe98SDimitry Andric }
12224e2fe98SDimitry Andric 
12324e2fe98SDimitry Andric bool FixFunctionBitcasts::runOnModule(Module &M) {
12424e2fe98SDimitry Andric   SmallVector<std::pair<Use *, Function *>, 0> Uses;
12524e2fe98SDimitry Andric 
12624e2fe98SDimitry Andric   // Collect all the places that need wrappers.
12724e2fe98SDimitry Andric   for (Function &F : M)
12824e2fe98SDimitry Andric     FindUses(&F, F, Uses);
12924e2fe98SDimitry Andric 
13024e2fe98SDimitry Andric   DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;
13124e2fe98SDimitry Andric 
13224e2fe98SDimitry Andric   for (auto &UseFunc : Uses) {
13324e2fe98SDimitry Andric     Use *U = UseFunc.first;
13424e2fe98SDimitry Andric     Function *F = UseFunc.second;
13524e2fe98SDimitry Andric     PointerType *PTy = cast<PointerType>(U->get()->getType());
13624e2fe98SDimitry Andric     FunctionType *Ty = dyn_cast<FunctionType>(PTy->getElementType());
13724e2fe98SDimitry Andric 
13824e2fe98SDimitry Andric     // If the function is casted to something like i8* as a "generic pointer"
13924e2fe98SDimitry Andric     // to be later casted to something else, we can't generate a wrapper for it.
14024e2fe98SDimitry Andric     // Just ignore such casts for now.
14124e2fe98SDimitry Andric     if (!Ty)
14224e2fe98SDimitry Andric       continue;
14324e2fe98SDimitry Andric 
14424e2fe98SDimitry Andric     auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
14524e2fe98SDimitry Andric     if (Pair.second)
14624e2fe98SDimitry Andric       Pair.first->second = CreateWrapper(F, Ty);
14724e2fe98SDimitry Andric 
14824e2fe98SDimitry Andric     Function *Wrapper = Pair.first->second;
14924e2fe98SDimitry Andric     if (!Wrapper)
15024e2fe98SDimitry Andric       continue;
15124e2fe98SDimitry Andric 
15224e2fe98SDimitry Andric     if (isa<Constant>(U->get()))
15324e2fe98SDimitry Andric       U->get()->replaceAllUsesWith(Wrapper);
15424e2fe98SDimitry Andric     else
15524e2fe98SDimitry Andric       U->set(Wrapper);
15624e2fe98SDimitry Andric   }
15724e2fe98SDimitry Andric 
15824e2fe98SDimitry Andric   return true;
15924e2fe98SDimitry Andric }
160