1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ReduceOperandsToArgs.h"
10 #include "Delta.h"
11 #include "llvm/ADT/Sequence.h"
12 #include "llvm/IR/Constants.h"
13 #include "llvm/IR/InstIterator.h"
14 #include "llvm/IR/InstrTypes.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
17 #include "llvm/Transforms/Utils/Cloning.h"
18 
19 using namespace llvm;
20 
21 static bool canReplaceFunction(Function *F) {
22   return all_of(F->uses(), [](Use &Op) {
23     if (auto *CI = dyn_cast<CallBase>(Op.getUser()))
24       return &CI->getCalledOperandUse() == &Op;
25     return false;
26   });
27 }
28 
29 static bool canReduceUse(Use &Op) {
30   Value *Val = Op.get();
31   Type *Ty = Val->getType();
32 
33   // Only replace operands that can be passed-by-value.
34   if (!Ty->isFirstClassType())
35     return false;
36 
37   // Don't pass labels/metadata as arguments.
38   if (Ty->isLabelTy() || Ty->isMetadataTy())
39     return false;
40 
41   // No need to replace values that are already arguments.
42   if (isa<Argument>(Val))
43     return false;
44 
45   // Do not replace literals.
46   if (isa<ConstantData>(Val))
47     return false;
48 
49   // Do not convert direct function calls to indirect calls.
50   if (auto *CI = dyn_cast<CallBase>(Op.getUser()))
51     if (&CI->getCalledOperandUse() == &Op)
52       return false;
53 
54   return true;
55 }
56 
57 /// Goes over OldF calls and replaces them with a call to NewF.
58 static void replaceFunctionCalls(Function *OldF, Function *NewF) {
59   SmallVector<CallBase *> Callers;
60   for (Use &U : OldF->uses()) {
61     auto *CI = cast<CallBase>(U.getUser());
62     assert(&U == &CI->getCalledOperandUse());
63     assert(CI->getCalledFunction() == OldF);
64     Callers.push_back(CI);
65   }
66 
67   // Call arguments for NewF.
68   SmallVector<Value *> Args(NewF->arg_size(), nullptr);
69 
70   // Fill up the additional parameters with undef values.
71   for (auto ArgIdx : llvm::seq<size_t>(OldF->arg_size(), NewF->arg_size())) {
72     Type *NewArgTy = NewF->getArg(ArgIdx)->getType();
73     Args[ArgIdx] = UndefValue::get(NewArgTy);
74   }
75 
76   for (CallBase *CI : Callers) {
77     // Preserve the original function arguments.
78     for (auto Z : zip_first(CI->args(), Args))
79       std::get<1>(Z) = std::get<0>(Z);
80 
81     // Also preserve operand bundles.
82     SmallVector<OperandBundleDef> OperandBundles;
83     CI->getOperandBundlesAsDefs(OperandBundles);
84 
85     // Create the new function call.
86     CallBase *NewCI;
87     if (auto *II = dyn_cast<InvokeInst>(CI)) {
88       NewCI = InvokeInst::Create(NewF, cast<InvokeInst>(II)->getNormalDest(),
89                                  cast<InvokeInst>(II)->getUnwindDest(), Args,
90                                  OperandBundles, CI->getName());
91     } else {
92       assert(isa<CallInst>(CI));
93       NewCI = CallInst::Create(NewF, Args, OperandBundles, CI->getName());
94     }
95     NewCI->setCallingConv(NewF->getCallingConv());
96 
97     // Do the replacement for this use.
98     if (!CI->use_empty())
99       CI->replaceAllUsesWith(NewCI);
100     ReplaceInstWithInst(CI, NewCI);
101   }
102 }
103 
104 /// Add a new function argument to @p F for each use in @OpsToReplace, and
105 /// replace those operand values with the new function argument.
106 static void substituteOperandWithArgument(Function *OldF,
107                                           ArrayRef<Use *> OpsToReplace) {
108   if (OpsToReplace.empty())
109     return;
110 
111   SetVector<Value *> UniqueValues;
112   for (Use *Op : OpsToReplace)
113     UniqueValues.insert(Op->get());
114 
115   // Determine the new function's signature.
116   SmallVector<Type *> NewArgTypes;
117   llvm::append_range(NewArgTypes, OldF->getFunctionType()->params());
118   size_t ArgOffset = NewArgTypes.size();
119   for (Value *V : UniqueValues)
120     NewArgTypes.push_back(V->getType());
121   FunctionType *FTy =
122       FunctionType::get(OldF->getFunctionType()->getReturnType(), NewArgTypes,
123                         OldF->getFunctionType()->isVarArg());
124 
125   // Create the new function...
126   Function *NewF =
127       Function::Create(FTy, OldF->getLinkage(), OldF->getAddressSpace(),
128                        OldF->getName(), OldF->getParent());
129 
130   // In order to preserve function order, we move NewF behind OldF
131   NewF->removeFromParent();
132   OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF);
133 
134   // Preserve the parameters of OldF.
135   ValueToValueMapTy VMap;
136   for (auto Z : zip_first(OldF->args(), NewF->args())) {
137     Argument &OldArg = std::get<0>(Z);
138     Argument &NewArg = std::get<1>(Z);
139 
140     NewArg.setName(OldArg.getName()); // Copy the name over...
141     VMap[&OldArg] = &NewArg;          // Add mapping to VMap
142   }
143 
144   // Adjust the new parameters.
145   ValueToValueMapTy OldValMap;
146   for (auto Z : zip_first(UniqueValues, drop_begin(NewF->args(), ArgOffset))) {
147     Value *OldVal = std::get<0>(Z);
148     Argument &NewArg = std::get<1>(Z);
149 
150     NewArg.setName(OldVal->getName());
151     OldValMap[OldVal] = &NewArg;
152   }
153 
154   SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
155   CloneFunctionInto(NewF, OldF, VMap, CloneFunctionChangeType::LocalChangesOnly,
156                     Returns, "", /*CodeInfo=*/nullptr);
157 
158   // Replace the actual operands.
159   for (Use *Op : OpsToReplace) {
160     Value *NewArg = OldValMap.lookup(Op->get());
161     auto *NewUser = cast<Instruction>(VMap.lookup(Op->getUser()));
162     NewUser->setOperand(Op->getOperandNo(), NewArg);
163   }
164 
165   // Replace all OldF uses with NewF.
166   replaceFunctionCalls(OldF, NewF);
167 
168   // Rename NewF to OldF's name.
169   std::string FName = OldF->getName().str();
170   OldF->replaceAllUsesWith(ConstantExpr::getBitCast(NewF, OldF->getType()));
171   OldF->eraseFromParent();
172   NewF->setName(FName);
173 }
174 
175 static void reduceOperandsToArgs(Oracle &O, Module &Program) {
176   SmallVector<Use *> OperandsToReduce;
177   for (Function &F : make_early_inc_range(Program.functions())) {
178     if (!canReplaceFunction(&F))
179       continue;
180     OperandsToReduce.clear();
181     for (Instruction &I : instructions(&F)) {
182       for (Use &Op : I.operands()) {
183         if (!canReduceUse(Op))
184           continue;
185         if (O.shouldKeep())
186           continue;
187 
188         OperandsToReduce.push_back(&Op);
189       }
190     }
191 
192     substituteOperandWithArgument(&F, OperandsToReduce);
193   }
194 }
195 
196 void llvm::reduceOperandsToArgsDeltaPass(TestRunner &Test) {
197   outs() << "*** Converting operands to function arguments ...\n";
198   return runDeltaPass(Test, reduceOperandsToArgs);
199 }
200