1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "ReduceOperandsToArgs.h" 10 #include "Delta.h" 11 #include "llvm/ADT/Sequence.h" 12 #include "llvm/IR/Constants.h" 13 #include "llvm/IR/InstIterator.h" 14 #include "llvm/IR/InstrTypes.h" 15 #include "llvm/IR/Instructions.h" 16 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 17 #include "llvm/Transforms/Utils/Cloning.h" 18 19 using namespace llvm; 20 21 static bool canReplaceFunction(Function *F) { 22 return all_of(F->uses(), [](Use &Op) { 23 if (auto *CI = dyn_cast<CallBase>(Op.getUser())) 24 return &CI->getCalledOperandUse() == &Op; 25 return false; 26 }); 27 } 28 29 static bool canReduceUse(Use &Op) { 30 Value *Val = Op.get(); 31 Type *Ty = Val->getType(); 32 33 // Only replace operands that can be passed-by-value. 34 if (!Ty->isFirstClassType()) 35 return false; 36 37 // Don't pass labels/metadata as arguments. 38 if (Ty->isLabelTy() || Ty->isMetadataTy()) 39 return false; 40 41 // No need to replace values that are already arguments. 42 if (isa<Argument>(Val)) 43 return false; 44 45 // Do not replace literals. 46 if (isa<ConstantData>(Val)) 47 return false; 48 49 // Do not convert direct function calls to indirect calls. 50 if (auto *CI = dyn_cast<CallBase>(Op.getUser())) 51 if (&CI->getCalledOperandUse() == &Op) 52 return false; 53 54 return true; 55 } 56 57 /// Goes over OldF calls and replaces them with a call to NewF. 58 static void replaceFunctionCalls(Function *OldF, Function *NewF) { 59 SmallVector<CallBase *> Callers; 60 for (Use &U : OldF->uses()) { 61 auto *CI = cast<CallBase>(U.getUser()); 62 assert(&U == &CI->getCalledOperandUse()); 63 assert(CI->getCalledFunction() == OldF); 64 Callers.push_back(CI); 65 } 66 67 // Call arguments for NewF. 68 SmallVector<Value *> Args(NewF->arg_size(), nullptr); 69 70 // Fill up the additional parameters with undef values. 71 for (auto ArgIdx : llvm::seq<size_t>(OldF->arg_size(), NewF->arg_size())) { 72 Type *NewArgTy = NewF->getArg(ArgIdx)->getType(); 73 Args[ArgIdx] = UndefValue::get(NewArgTy); 74 } 75 76 for (CallBase *CI : Callers) { 77 // Preserve the original function arguments. 78 for (auto Z : zip_first(CI->args(), Args)) 79 std::get<1>(Z) = std::get<0>(Z); 80 81 // Also preserve operand bundles. 82 SmallVector<OperandBundleDef> OperandBundles; 83 CI->getOperandBundlesAsDefs(OperandBundles); 84 85 // Create the new function call. 86 CallBase *NewCI; 87 if (auto *II = dyn_cast<InvokeInst>(CI)) { 88 NewCI = InvokeInst::Create(NewF, cast<InvokeInst>(II)->getNormalDest(), 89 cast<InvokeInst>(II)->getUnwindDest(), Args, 90 OperandBundles, CI->getName()); 91 } else { 92 assert(isa<CallInst>(CI)); 93 NewCI = CallInst::Create(NewF, Args, OperandBundles, CI->getName()); 94 } 95 NewCI->setCallingConv(NewF->getCallingConv()); 96 97 // Do the replacement for this use. 98 if (!CI->use_empty()) 99 CI->replaceAllUsesWith(NewCI); 100 ReplaceInstWithInst(CI, NewCI); 101 } 102 } 103 104 /// Add a new function argument to @p F for each use in @OpsToReplace, and 105 /// replace those operand values with the new function argument. 106 static void substituteOperandWithArgument(Function *OldF, 107 ArrayRef<Use *> OpsToReplace) { 108 if (OpsToReplace.empty()) 109 return; 110 111 SetVector<Value *> UniqueValues; 112 for (Use *Op : OpsToReplace) 113 UniqueValues.insert(Op->get()); 114 115 // Determine the new function's signature. 116 SmallVector<Type *> NewArgTypes; 117 llvm::append_range(NewArgTypes, OldF->getFunctionType()->params()); 118 size_t ArgOffset = NewArgTypes.size(); 119 for (Value *V : UniqueValues) 120 NewArgTypes.push_back(V->getType()); 121 FunctionType *FTy = 122 FunctionType::get(OldF->getFunctionType()->getReturnType(), NewArgTypes, 123 OldF->getFunctionType()->isVarArg()); 124 125 // Create the new function... 126 Function *NewF = 127 Function::Create(FTy, OldF->getLinkage(), OldF->getAddressSpace(), 128 OldF->getName(), OldF->getParent()); 129 130 // In order to preserve function order, we move NewF behind OldF 131 NewF->removeFromParent(); 132 OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF); 133 134 // Preserve the parameters of OldF. 135 ValueToValueMapTy VMap; 136 for (auto Z : zip_first(OldF->args(), NewF->args())) { 137 Argument &OldArg = std::get<0>(Z); 138 Argument &NewArg = std::get<1>(Z); 139 140 NewArg.setName(OldArg.getName()); // Copy the name over... 141 VMap[&OldArg] = &NewArg; // Add mapping to VMap 142 } 143 144 // Adjust the new parameters. 145 ValueToValueMapTy OldValMap; 146 for (auto Z : zip_first(UniqueValues, drop_begin(NewF->args(), ArgOffset))) { 147 Value *OldVal = std::get<0>(Z); 148 Argument &NewArg = std::get<1>(Z); 149 150 NewArg.setName(OldVal->getName()); 151 OldValMap[OldVal] = &NewArg; 152 } 153 154 SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned. 155 CloneFunctionInto(NewF, OldF, VMap, CloneFunctionChangeType::LocalChangesOnly, 156 Returns, "", /*CodeInfo=*/nullptr); 157 158 // Replace the actual operands. 159 for (Use *Op : OpsToReplace) { 160 Value *NewArg = OldValMap.lookup(Op->get()); 161 auto *NewUser = cast<Instruction>(VMap.lookup(Op->getUser())); 162 NewUser->setOperand(Op->getOperandNo(), NewArg); 163 } 164 165 // Replace all OldF uses with NewF. 166 replaceFunctionCalls(OldF, NewF); 167 168 // Rename NewF to OldF's name. 169 std::string FName = OldF->getName().str(); 170 OldF->replaceAllUsesWith(ConstantExpr::getBitCast(NewF, OldF->getType())); 171 OldF->eraseFromParent(); 172 NewF->setName(FName); 173 } 174 175 static void reduceOperandsToArgs(Oracle &O, Module &Program) { 176 SmallVector<Use *> OperandsToReduce; 177 for (Function &F : make_early_inc_range(Program.functions())) { 178 if (!canReplaceFunction(&F)) 179 continue; 180 OperandsToReduce.clear(); 181 for (Instruction &I : instructions(&F)) { 182 for (Use &Op : I.operands()) { 183 if (!canReduceUse(Op)) 184 continue; 185 if (O.shouldKeep()) 186 continue; 187 188 OperandsToReduce.push_back(&Op); 189 } 190 } 191 192 substituteOperandWithArgument(&F, OperandsToReduce); 193 } 194 } 195 196 void llvm::reduceOperandsToArgsDeltaPass(TestRunner &Test) { 197 outs() << "*** Converting operands to function arguments ...\n"; 198 return runDeltaPass(Test, reduceOperandsToArgs); 199 } 200