//===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass promotes "by reference" arguments to be "by value" arguments.  In
// practice, this means looking for internal functions that have pointer
// arguments.  If it can prove, through the use of alias analysis, that an
// argument is *only* loaded, then it can pass the value into the function
// instead of the address of the value.  This can cause recursive simplification
// of code and lead to the elimination of allocas (especially in C++ template
// code like the STL).
//
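// For example, given an internal function whose pointer argument is only
// loaded, the pass rewrites the callee and every call site to pass the loaded
// value directly (an illustrative sketch, not from an actual test case):
//
//   define internal i32 @callee(i32* %p) {
//     %v = load i32, i32* %p
//     ret i32 %v
//   }
//
// becomes:
//
//   define internal i32 @callee(i32 %p.val) {
//     ret i32 %p.val
//   }
//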
// This pass also handles aggregate arguments that are passed into a function,
// scalarizing them if the elements of the aggregate are only loaded.  Note that
// by default it refuses to scalarize aggregates which would require passing in
// more than three operands to the function, because passing thousands of
// operands for a large array or structure is unprofitable! This limit can be
// configured or disabled, however.
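//
// For example (again an illustrative sketch), a small byval aggregate such as:
//
//   define internal void @f({ i32, i64 }* byval({ i32, i64 }) align 8 %s)
//
// becomes a function that takes the elements as scalars:
//
//   define internal void @f(i32 %s.0, i64 %s.1)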
//
// Note that this transformation could also be done for arguments that are only
// stored to (returning the value instead), but this is not currently
// implemented.  This case would be best handled when and if LLVM begins
// supporting multiple return values from functions.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/ArgumentPromotion.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using namespace llvm;

#define DEBUG_TYPE "argpromotion"

STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted");
STATISTIC(NumByValArgsPromoted, "Number of byval arguments promoted");
STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated");

struct ArgPart {
  Type *Ty;
  Align Alignment;
  /// A representative guaranteed-executed load instruction for use by
  /// metadata transfer.
  LoadInst *MustExecLoad;
};
using OffsetAndArgPart = std::pair<int64_t, ArgPart>;

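/// Create a GEP that addresses Offset bytes from Ptr and has result element
/// type ResElemTy. As a sketch: for a non-opaque pointer to { i32, i64 } and
/// Offset 8, this would prefer
///   getelementptr { i32, i64 }, { i32, i64 }* %p, i32 0, i32 1
/// over the equivalent i8-based GEP (the exact IR depends on the data layout).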
static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL,
                            Value *Ptr, Type *ResElemTy, int64_t Offset) {
  // For non-opaque pointers, try to create a "nice" GEP if possible; otherwise
  // fall back to an i8 GEP to a specific offset.
  unsigned AddrSpace = Ptr->getType()->getPointerAddressSpace();
  APInt OrigOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset);
  if (!Ptr->getType()->isOpaquePointerTy()) {
    Type *OrigElemTy = Ptr->getType()->getNonOpaquePointerElementType();
    if (OrigOffset == 0 && OrigElemTy == ResElemTy)
      return Ptr;

    if (OrigElemTy->isSized()) {
      APInt TmpOffset = OrigOffset;
      Type *TmpTy = OrigElemTy;
      SmallVector<APInt> IntIndices =
          DL.getGEPIndicesForOffset(TmpTy, TmpOffset);
      if (TmpOffset == 0) {
        // Try to add trailing zero indices to reach the right type.
        while (TmpTy != ResElemTy) {
          Type *NextTy = GetElementPtrInst::getTypeAtIndex(TmpTy, (uint64_t)0);
          if (!NextTy)
            break;

          IntIndices.push_back(APInt::getZero(
              isa<StructType>(TmpTy) ? 32 : OrigOffset.getBitWidth()));
          TmpTy = NextTy;
        }

        SmallVector<Value *> Indices;
        for (const APInt &Index : IntIndices)
          Indices.push_back(IRB.getInt(Index));

        if (OrigOffset != 0 || TmpTy == ResElemTy) {
          Ptr = IRB.CreateGEP(OrigElemTy, Ptr, Indices);
          return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace));
        }
      }
    }
  }

  if (OrigOffset != 0) {
    Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy(AddrSpace));
    Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(OrigOffset));
  }
  return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace));
}

/// doPromotion - This method actually performs the promotion of the specified
/// arguments, and returns the new function.  At this point, we know that it's
/// safe to do so.
static Function *doPromotion(
    Function *F,
    const DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> &ArgsToPromote,
    SmallPtrSetImpl<Argument *> &ByValArgsToTransform,
    Optional<function_ref<void(CallBase &OldCS, CallBase &NewCS)>>
        ReplaceCallSite) {
  // Start by computing a new prototype for the function, which is the same as
  // the old function, but has modified arguments.
  FunctionType *FTy = F->getFunctionType();
  std::vector<Type *> Params;

  // ArgAttrVec - Keep track of the parameter attributes for the arguments
  // that we are *not* promoting. For the ones that we do promote, the
  // parameter attributes are lost.
  SmallVector<AttributeSet, 8> ArgAttrVec;
  AttributeList PAL = F->getAttributes();

  // First, determine the new argument list.
  unsigned ArgNo = 0;
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++ArgNo) {
    if (ByValArgsToTransform.count(&*I)) {
      // Simple byval argument? Just add all the struct element types.
      Type *AgTy = I->getParamByValType();
      StructType *STy = cast<StructType>(AgTy);
      llvm::append_range(Params, STy->elements());
      ArgAttrVec.insert(ArgAttrVec.end(), STy->getNumElements(),
                        AttributeSet());
      ++NumByValArgsPromoted;
    } else if (!ArgsToPromote.count(&*I)) {
      // Unchanged argument
      Params.push_back(I->getType());
      ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo));
    } else if (I->use_empty()) {
      // Dead argument (these are always marked as promotable).
      ++NumArgumentsDead;
    } else {
      const auto &ArgParts = ArgsToPromote.find(&*I)->second;
      for (const auto &Pair : ArgParts) {
        Params.push_back(Pair.second.Ty);
        ArgAttrVec.push_back(AttributeSet());
      }
      ++NumArgumentsPromoted;
    }
  }

  Type *RetTy = FTy->getReturnType();

  // Construct the new function type using the new arguments.
  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());

  // Create the new function body and insert it into the module.
  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(),
                                  F->getName());
  NF->copyAttributesFrom(F);
  NF->copyMetadata(F, 0);

  // The new function will have the !dbg metadata copied from the original
  // function. The original function may not be deleted, and !dbg metadata
  // needs to be unique, so we drop it from the original function.
  F->setSubprogram(nullptr);

  LLVM_DEBUG(dbgs() << "ARG PROMOTION:  Promoting to:" << *NF << "\n"
                    << "From: " << *F);

  // Recompute the parameter attributes list based on the new arguments for
  // the function.
  NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(),
                                       PAL.getRetAttrs(), ArgAttrVec));
  ArgAttrVec.clear();

  F->getParent()->getFunctionList().insert(F->getIterator(), NF);
  NF->takeName(F);

  // Loop over all of the callers of the function, transforming the call sites
  // to pass in the loaded pointers.
  //
  SmallVector<Value *, 16> Args;
  const DataLayout &DL = F->getParent()->getDataLayout();
  while (!F->use_empty()) {
    CallBase &CB = cast<CallBase>(*F->user_back());
    assert(CB.getCalledFunction() == F);
    const AttributeList &CallPAL = CB.getAttributes();
    IRBuilder<NoFolder> IRB(&CB);

    // Loop over the operands, inserting GEP and loads in the caller as
    // appropriate.
    auto AI = CB.arg_begin();
    ArgNo = 0;
    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
         ++I, ++AI, ++ArgNo)
      if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) {
        Args.push_back(*AI); // Unmodified argument
        ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
      } else if (ByValArgsToTransform.count(&*I)) {
        // Emit a GEP and load for each element of the struct.
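        // As an illustrative sketch, for a byval argument %s of type
        // { i32, i64 } this emits roughly:
        //   %s.0 = getelementptr { i32, i64 }, { i32, i64 }* %s, i32 0, i32 0
        //   %s.0.val = load i32, i32* %s.0
        // and likewise for each remaining element.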
        Type *AgTy = I->getParamByValType();
        StructType *STy = cast<StructType>(AgTy);
        Value *Idxs[2] = {
            ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr};
        const StructLayout *SL = DL.getStructLayout(STy);
        Align StructAlign = *I->getParamAlign();
        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
          Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
          auto *Idx =
              IRB.CreateGEP(STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i));
          // TODO: Tell AA about the new values?
          Align Alignment =
              commonAlignment(StructAlign, SL->getElementOffset(i));
          Args.push_back(IRB.CreateAlignedLoad(
              STy->getElementType(i), Idx, Alignment, Idx->getName() + ".val"));
          ArgAttrVec.push_back(AttributeSet());
        }
      } else if (!I->use_empty()) {
        Value *V = *AI;
        const auto &ArgParts = ArgsToPromote.find(&*I)->second;
        for (const auto &Pair : ArgParts) {
          LoadInst *LI = IRB.CreateAlignedLoad(
              Pair.second.Ty,
              createByteGEP(IRB, DL, V, Pair.second.Ty, Pair.first),
              Pair.second.Alignment, V->getName() + ".val");
          if (Pair.second.MustExecLoad) {
            LI->setAAMetadata(Pair.second.MustExecLoad->getAAMetadata());
            LI->copyMetadata(*Pair.second.MustExecLoad,
                             {LLVMContext::MD_range, LLVMContext::MD_nonnull,
                              LLVMContext::MD_dereferenceable,
                              LLVMContext::MD_dereferenceable_or_null,
                              LLVMContext::MD_align, LLVMContext::MD_noundef});
          }
          Args.push_back(LI);
          ArgAttrVec.push_back(AttributeSet());
        }
      }

    // Push any varargs arguments on the list.
    for (; AI != CB.arg_end(); ++AI, ++ArgNo) {
      Args.push_back(*AI);
      ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
    }

    SmallVector<OperandBundleDef, 1> OpBundles;
    CB.getOperandBundlesAsDefs(OpBundles);

    CallBase *NewCS = nullptr;
    if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) {
      NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
                                 Args, OpBundles, "", &CB);
    } else {
      auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", &CB);
      NewCall->setTailCallKind(cast<CallInst>(&CB)->getTailCallKind());
      NewCS = NewCall;
    }
    NewCS->setCallingConv(CB.getCallingConv());
    NewCS->setAttributes(AttributeList::get(F->getContext(),
                                            CallPAL.getFnAttrs(),
                                            CallPAL.getRetAttrs(), ArgAttrVec));
    NewCS->copyMetadata(CB, {LLVMContext::MD_prof, LLVMContext::MD_dbg});
    Args.clear();
    ArgAttrVec.clear();

    // Update the callgraph to know that the callsite has been transformed.
    if (ReplaceCallSite)
      (*ReplaceCallSite)(CB, *NewCS);

    if (!CB.use_empty()) {
      CB.replaceAllUsesWith(NewCS);
      NewCS->takeName(&CB);
    }

    // Finally, remove the old call from the program, reducing the use-count of
    // F.
    CB.eraseFromParent();
  }

  // Since we have now created the new function, splice the body of the old
  // function right into the new function, leaving the old rotting hulk of the
  // function empty.
  NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());

  // Loop over the argument list, transferring uses of the old arguments over to
  // the new arguments, also transferring over the names as well.
  Function::arg_iterator I2 = NF->arg_begin();
  for (Argument &Arg : F->args()) {
    if (!ArgsToPromote.count(&Arg) && !ByValArgsToTransform.count(&Arg)) {
      // If this is an unmodified argument, move the name and users over to the
      // new version.
      Arg.replaceAllUsesWith(&*I2);
      I2->takeName(&Arg);
      ++I2;
      continue;
    }

    if (ByValArgsToTransform.count(&Arg)) {
      // In the callee, we create an alloca, and store each of the new incoming
      // arguments into the alloca.
      Instruction *InsertPt = &NF->begin()->front();

      // Just add all the struct element types.
      Type *AgTy = Arg.getParamByValType();
      Align StructAlign = *Arg.getParamAlign();
      Value *TheAlloca = new AllocaInst(AgTy, DL.getAllocaAddrSpace(), nullptr,
                                        StructAlign, "", InsertPt);
      StructType *STy = cast<StructType>(AgTy);
      Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(F->getContext()), 0),
                        nullptr};
      const StructLayout *SL = DL.getStructLayout(STy);

      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
        Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
        Value *Idx = GetElementPtrInst::Create(
            AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(i),
            InsertPt);
        I2->setName(Arg.getName() + "." + Twine(i));
        Align Alignment = commonAlignment(StructAlign, SL->getElementOffset(i));
        new StoreInst(&*I2++, Idx, false, Alignment, InsertPt);
      }

      // Anything that used the arg should now use the alloca.
      Arg.replaceAllUsesWith(TheAlloca);
      TheAlloca->takeName(&Arg);
      continue;
    }

    // There may also be metadata uses, e.g. for llvm.dbg.value. Replace them
    // with undef after handling the other, regular uses.
    auto RauwUndefMetadata = make_scope_exit(
        [&]() { Arg.replaceAllUsesWith(UndefValue::get(Arg.getType())); });

    if (Arg.use_empty())
      continue;

    SmallDenseMap<int64_t, Argument *> OffsetToArg;
    for (const auto &Pair : ArgsToPromote.find(&Arg)->second) {
      Argument &NewArg = *I2++;
      NewArg.setName(Arg.getName() + "." + Twine(Pair.first) + ".val");
      OffsetToArg.insert({Pair.first, &NewArg});
    }

    // Otherwise, if we promoted this argument, then all users are load
    // instructions (with possible casts and GEPs in between).
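    // As an illustrative sketch, a user chain such as:
    //   %p = getelementptr { i32, i32 }, { i32, i32 }* %arg, i32 0, i32 1
    //   %v = load i32, i32* %p
    // is rewritten to use the promoted argument for offset 4 directly, and
    // the now-dead GEP is removed below.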

    SmallVector<Value *, 16> Worklist;
    SmallVector<Instruction *, 16> DeadInsts;
    append_range(Worklist, Arg.users());
    while (!Worklist.empty()) {
      Value *V = Worklist.pop_back_val();
      if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V)) {
        DeadInsts.push_back(cast<Instruction>(V));
        append_range(Worklist, V->users());
        continue;
      }

      if (auto *LI = dyn_cast<LoadInst>(V)) {
        Value *Ptr = LI->getPointerOperand();
        APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
        Ptr =
            Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                   /* AllowNonInbounds */ true);
        assert(Ptr == &Arg && "Not constant offset from arg?");
        LI->replaceAllUsesWith(OffsetToArg[Offset.getSExtValue()]);
        DeadInsts.push_back(LI);
        continue;
      }

      llvm_unreachable("Unexpected user");
    }

    for (Instruction *I : DeadInsts) {
      I->replaceAllUsesWith(UndefValue::get(I->getType()));
      I->eraseFromParent();
    }
  }

  return NF;
}

/// Return true if we can prove that all callers pass in a valid pointer for
/// the specified function argument.
static bool allCallersPassValidPointerForArgument(Argument *Arg,
                                                  Align NeededAlign,
                                                  uint64_t NeededDerefBytes) {
  Function *Callee = Arg->getParent();
  const DataLayout &DL = Callee->getParent()->getDataLayout();
  APInt Bytes(64, NeededDerefBytes);

  // Check if the argument itself is marked dereferenceable and aligned.
  if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL))
    return true;

  // Look at all call sites of the function.  At this point we know all call
  // sites are direct calls.
  return all_of(Callee->users(), [&](User *U) {
    CallBase &CB = cast<CallBase>(*U);
    return isDereferenceableAndAlignedPointer(
        CB.getArgOperand(Arg->getArgNo()), NeededAlign, Bytes, DL);
  });
}

/// Determine whether this argument is safe to promote, and find the argument
/// parts it can be promoted into.
static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
                         unsigned MaxElements, bool IsRecursive,
                         SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
  // Quick exit for unused arguments.
  if (Arg->use_empty())
    return true;

  // We can only promote this argument if all of the uses are loads at known
  // offsets.
  //
  // Promoting the argument causes it to be loaded in the caller
  // unconditionally. This is only safe if we can prove that either the load
  // would have happened in the callee anyway (i.e., there is a load in the
  // entry block) or the pointer passed in at every call site is guaranteed to
  // be valid.
  // In the former case, invalid loads can happen, but they would have happened
  // anyway; in the latter case, invalid loads won't happen. Either way, we
  // never introduce an invalid load that wouldn't have happened in the
  // original code.
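  //
  // As an illustrative sketch, a load in the entry block:
  //   %v = load i32, i32* %arg
  // is always safe to hoist into the caller, while a conditionally executed
  // load is only safe if every call site passes a pointer known to be
  // dereferenceable (and sufficiently aligned) for the accessed bytes.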

  SmallDenseMap<int64_t, ArgPart, 4> ArgParts;
  Align NeededAlign(1);
  uint64_t NeededDerefBytes = 0;

  // Returns None if this load is not based on the argument. Returns true if
  // we can promote the load, false otherwise.
  auto HandleLoad = [&](LoadInst *LI,
                        bool GuaranteedToExecute) -> Optional<bool> {
    // Don't promote volatile or atomic loads.
    if (!LI->isSimple())
      return false;

    Value *Ptr = LI->getPointerOperand();
    APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
    Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                 /* AllowNonInbounds */ true);
    if (Ptr != Arg)
      return None;

    if (Offset.getSignificantBits() >= 64)
      return false;

    Type *Ty = LI->getType();
    TypeSize Size = DL.getTypeStoreSize(Ty);
    // Don't try to promote scalable types.
    if (Size.isScalable())
      return false;

    // If this is a recursive function and one of the types is a pointer,
    // then promoting it might lead to recursive promotion.
    if (IsRecursive && Ty->isPointerTy())
      return false;

    int64_t Off = Offset.getSExtValue();
    auto Pair = ArgParts.try_emplace(
        Off, ArgPart{Ty, LI->getAlign(), GuaranteedToExecute ? LI : nullptr});
    ArgPart &Part = Pair.first->second;
    bool OffsetNotSeenBefore = Pair.second;

    // We only allow promoting up to a fixed number of parts of the aggregate.
    if (MaxElements > 0 && ArgParts.size() >= MaxElements) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "more than " << MaxElements << " parts\n");
      return false;
    }

    // For now, we only support loading one specific type at a given offset.
    if (Part.Ty != Ty) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "loaded via both " << *Part.Ty << " and " << *Ty
                        << " at offset " << Off << "\n");
      return false;
    }

    // If this load is not guaranteed to execute and we haven't seen a load at
    // this offset before (or it had lower alignment), then we need to remember
    // that requirement.
    // Note that skipping loads of previously seen offsets is only correct
    // because we only allow a single type for a given offset, which also means
    // that the number of accessed bytes will be the same.
    if (!GuaranteedToExecute &&
        (OffsetNotSeenBefore || Part.Alignment < LI->getAlign())) {
      // We won't be able to prove dereferenceability for negative offsets.
      if (Off < 0)
        return false;

      // If the offset is not aligned, an aligned base pointer won't help.
      if (!isAligned(LI->getAlign(), Off))
        return false;

      NeededDerefBytes = std::max(NeededDerefBytes, Off + Size.getFixedValue());
      NeededAlign = std::max(NeededAlign, LI->getAlign());
    }

    Part.Alignment = std::max(Part.Alignment, LI->getAlign());
    return true;
  };

  // Look for loads that are guaranteed to execute on entry.
  for (Instruction &I : Arg->getParent()->getEntryBlock()) {
    if (LoadInst *LI = dyn_cast<LoadInst>(&I))
      if (Optional<bool> Res = HandleLoad(LI, /* GuaranteedToExecute */ true))
        if (!*Res)
          return false;

    if (!isGuaranteedToTransferExecutionToSuccessor(&I))
      break;
  }

  // Now look at all loads of the argument. Remember the load instructions
  // for the aliasing check below.
  SmallVector<Value *, 16> Worklist;
  SmallPtrSet<Value *, 16> Visited;
  SmallVector<LoadInst *, 16> Loads;
  auto AppendUsers = [&](Value *V) {
    for (User *U : V->users())
      if (Visited.insert(U).second)
        Worklist.push_back(U);
  };
  AppendUsers(Arg);
  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    if (isa<BitCastInst>(V)) {
      AppendUsers(V);
      continue;
    }

    if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
      if (!GEP->hasAllConstantIndices())
        return false;
      AppendUsers(V);
      continue;
    }

    if (auto *LI = dyn_cast<LoadInst>(V)) {
      if (!*HandleLoad(LI, /* GuaranteedToExecute */ false))
        return false;
      Loads.push_back(LI);
      continue;
    }

    // Unknown user.
    LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                      << "unknown user " << *V << "\n");
    return false;
  }

  if (NeededDerefBytes || NeededAlign > 1) {
    // Try to prove the needed dereferenceability and alignment at every call
    // site.
    if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
                                               NeededDerefBytes)) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "not dereferenceable or aligned\n");
      return false;
    }
  }

  if (ArgParts.empty())
    return true; // No users, this is a dead argument.

  // Sort parts by offset.
  append_range(ArgPartsVec, ArgParts);
  sort(ArgPartsVec,
       [](const auto &A, const auto &B) { return A.first < B.first; });

  // Make sure the parts are non-overlapping.
  // TODO: As we're doing pure load promotion here, overlap should be fine from
  // a correctness perspective. Profitability is less obvious though.
  int64_t Offset = ArgPartsVec[0].first;
  for (const auto &Pair : ArgPartsVec) {
    if (Pair.first < Offset)
      return false; // Overlap with previous part.

    Offset = Pair.first + DL.getTypeStoreSize(Pair.second.Ty);
  }

  // Okay, now we know that the argument is only used by load instructions and
  // it is safe to unconditionally perform all of them. Use alias analysis to
  // check to see if the pointer is guaranteed to not be modified from entry of
  // the function to each of the load instructions.

  // Because there could be many load instructions, remember which blocks we
  // know to be transparent to the load.
  df_iterator_default_set<BasicBlock *, 16> TranspBlocks;

  for (LoadInst *Load : Loads) {
    // Check to see if the load is invalidated from the start of the block to
    // the load itself.
    BasicBlock *BB = Load->getParent();

    MemoryLocation Loc = MemoryLocation::get(Load);
    if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, ModRefInfo::Mod))
      return false; // Pointer is invalidated!

    // Now check every path from the entry block to the load for transparency.
    // To do this, we perform a depth first search on the inverse CFG from the
    // loading block.
    for (BasicBlock *P : predecessors(BB)) {
      for (BasicBlock *TranspBB : inverse_depth_first_ext(P, TranspBlocks))
        if (AAR.canBasicBlockModify(*TranspBB, Loc))
          return false;
    }
  }

  // If the path from the entry of the function to each load is free of
  // instructions that potentially invalidate the load, we can make the
  // transformation!
  return true;
}

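/// Checks whether a type has no padding bytes. As a sketch: { i32, i32 } is
/// densely packed, while { i8, i32 } is not on typical targets (three padding
/// bytes follow the i8 member).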
bool ArgumentPromotionPass::isDenselyPacked(Type *type, const DataLayout &DL) {
  // There is no size information, so be conservative.
  if (!type->isSized())
    return false;

  // If the alloc size is not equal to the storage size, then there are padding
  // bytes. For example, x86_fp80 on x86-64 has a bit size of 80 but an alloc
  // size of 128.
  if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type))
    return false;

  // FIXME: This isn't the right way to check for padding in vectors with
  // non-byte-size elements.
  if (VectorType *seqTy = dyn_cast<VectorType>(type))
    return isDenselyPacked(seqTy->getElementType(), DL);

  // For array types, check for padding within members.
  if (ArrayType *seqTy = dyn_cast<ArrayType>(type))
    return isDenselyPacked(seqTy->getElementType(), DL);

  if (!isa<StructType>(type))
    return true;

  // Check for padding within and between elements of a struct.
  StructType *StructTy = cast<StructType>(type);
  const StructLayout *Layout = DL.getStructLayout(StructTy);
  uint64_t StartPos = 0;
  for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) {
    Type *ElTy = StructTy->getElementType(i);
    if (!isDenselyPacked(ElTy, DL))
      return false;
    if (StartPos != Layout->getElementOffsetInBits(i))
      return false;
    StartPos += DL.getTypeAllocSizeInBits(ElTy);
  }

  return true;
}


/// Checks if the padding bytes of an argument could be accessed.
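/// As a sketch, storing a pointer derived from the argument, e.g.
///   store i32* %elt, i32** %slot
/// would let the padding be inspected through the escaped copy, so the byval
/// aggregate could not be safely scalarized element by element.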
static bool canPaddingBeAccessed(Argument *arg) {
  assert(arg->hasByValAttr());

  // Track all the pointers to the argument to make sure they are not captured.
  SmallPtrSet<Value *, 16> PtrValues;
  PtrValues.insert(arg);

  // Track all of the stores.
  SmallVector<StoreInst *, 16> Stores;

  // Scan through the uses recursively to make sure the pointer is always used
  // sanely.
  SmallVector<Value *, 16> WorkList(arg->users());
  while (!WorkList.empty()) {
    Value *V = WorkList.pop_back_val();
    if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) {
      if (PtrValues.insert(V).second)
        llvm::append_range(WorkList, V->users());
    } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) {
      Stores.push_back(Store);
    } else if (!isa<LoadInst>(V)) {
      return true;
    }
  }

  // Check to make sure the pointers aren't captured.
  for (StoreInst *Store : Stores)
    if (PtrValues.count(Store->getValueOperand()))
      return true;

  return false;
}

/// Check if callers and callee agree on how promoted arguments would be
/// passed.
static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
                                  const TargetTransformInfo &TTI) {
  return all_of(F.uses(), [&](const Use &U) {
    CallBase *CB = dyn_cast<CallBase>(U.getUser());
    if (!CB)
      return false;

    const Function *Caller = CB->getCaller();
    const Function *Callee = CB->getCalledFunction();
    return TTI.areTypesABICompatible(Caller, Callee, Types);
  });
}

/// promoteArguments - This method checks the specified function to see if
/// there are any promotable arguments and if it is safe to promote the
/// function (for example, all callers are direct).  If it is safe to promote
/// some arguments, it calls doPromotion.
static Function *
promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter,
                 unsigned MaxElements,
                 Optional<function_ref<void(CallBase &OldCS, CallBase &NewCS)>>
                     ReplaceCallSite,
                 const TargetTransformInfo &TTI, bool IsRecursive) {
  // Don't perform argument promotion for naked functions; otherwise we can end
  // up removing parameters that are seemingly 'not used' as they are referred
  // to in the assembly.
  if (F->hasFnAttribute(Attribute::Naked))
    return nullptr;

  // Make sure that it is local to this module.
  if (!F->hasLocalLinkage())
    return nullptr;

  // Don't promote arguments for variadic functions. Adding, removing, or
  // changing non-pack parameters can change the classification of pack
  // parameters. Frontends encode that classification at the call site in the
  // IR, while in the callee the classification is determined dynamically based
  // on the number of registers consumed so far.
  if (F->isVarArg())
    return nullptr;

  // Don't transform functions that receive inalloca arguments, as the
  // transformation may not be safe depending on calling convention.
  if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca))
    return nullptr;

  // First check: see if there are any pointer arguments!  If not, quick exit.
  SmallVector<Argument *, 16> PointerArgs;
  for (Argument &I : F->args())
    if (I.getType()->isPointerTy())
      PointerArgs.push_back(&I);
  if (PointerArgs.empty())
    return nullptr;

  // Second check: make sure that all callers are direct callers.  We can't
  // transform functions that have indirect callers.  Also see if the function
  // is self-recursive.
  for (Use &U : F->uses()) {
    CallBase *CB = dyn_cast<CallBase>(U.getUser());
    // Must be a direct call.
    if (CB == nullptr || !CB->isCallee(&U) ||
        CB->getFunctionType() != F->getFunctionType())
      return nullptr;

    // Can't change signature of musttail callee.
    if (CB->isMustTailCall())
      return nullptr;

    if (CB->getParent()->getParent() == F)
      IsRecursive = true;
  }

  // Can't change signature of musttail caller.
  // FIXME: Support promoting whole chain of musttail functions
  for (BasicBlock &BB : *F)
    if (BB.getTerminatingMustTailCall())
      return nullptr;

  const DataLayout &DL = F->getParent()->getDataLayout();

  AAResults &AAR = AARGetter(*F);

  // Check to see which arguments are promotable.  If an argument is promotable,
  // add it to ArgsToPromote.
  DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> ArgsToPromote;
  SmallPtrSet<Argument *, 8> ByValArgsToTransform;
  for (Argument *PtrArg : PointerArgs) {
    // Replace sret attribute with noalias. This reduces register pressure by
    // avoiding a register copy.
    if (PtrArg->hasStructRetAttr()) {
      unsigned ArgNo = PtrArg->getArgNo();
      F->removeParamAttr(ArgNo, Attribute::StructRet);
      F->addParamAttr(ArgNo, Attribute::NoAlias);
      for (Use &U : F->uses()) {
        CallBase &CB = cast<CallBase>(*U.getUser());
        CB.removeParamAttr(ArgNo, Attribute::StructRet);
        CB.addParamAttr(ArgNo, Attribute::NoAlias);
      }
    }

    // If this is a byval argument, and if the aggregate type is small, just
    // pass the elements, which is always safe if the passed value is densely
    // packed or if we can prove the padding bytes are never accessed.
    //
    // Only handle arguments with specified alignment; if it's unspecified, the
    // actual alignment of the argument is target-specific.
    Type *ByValTy = PtrArg->getParamByValType();
    bool isSafeToPromote =
        ByValTy && PtrArg->getParamAlign() &&
        (ArgumentPromotionPass::isDenselyPacked(ByValTy, DL) ||
         !canPaddingBeAccessed(PtrArg));
    if (isSafeToPromote) {
      if (StructType *STy = dyn_cast<StructType>(ByValTy)) {
        if (MaxElements > 0 && STy->getNumElements() > MaxElements) {
          LLVM_DEBUG(dbgs() << "argpromotion disabled promoting argument '"
                            << PtrArg->getName()
                            << "' because it would require adding more"
                            << " than " << MaxElements
                            << " arguments to the function.\n");
          continue;
        }

        SmallVector<Type *, 4> Types;
        append_range(Types, STy->elements());

        // If all the elements are single-value types, we can promote it.
        bool AllSimple =
            all_of(Types, [](Type *Ty) { return Ty->isSingleValueType(); });

        // Safe to transform, don't even bother trying to "promote" it.
        // Passing the elements as scalars will allow SROA to hack on the new
        // alloca we introduce.
        if (AllSimple && areTypesABICompatible(Types, *F, TTI)) {
          ByValArgsToTransform.insert(PtrArg);
          continue;
        }
      }
    }

    // Otherwise, see if we can promote the pointer to its value.
    SmallVector<OffsetAndArgPart, 4> ArgParts;
    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
      SmallVector<Type *, 4> Types;
      for (const auto &Pair : ArgParts)
        Types.push_back(Pair.second.Ty);

      if (areTypesABICompatible(Types, *F, TTI))
        ArgsToPromote.insert({PtrArg, std::move(ArgParts)});
    }
  }

  // No promotable pointer arguments.
  if (ArgsToPromote.empty() && ByValArgsToTransform.empty())
    return nullptr;

  return doPromotion(F, ArgsToPromote, ByValArgsToTransform, ReplaceCallSite);
}

PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C,
                                             CGSCCAnalysisManager &AM,
                                             LazyCallGraph &CG,
                                             CGSCCUpdateResult &UR) {
  bool Changed = false, LocalChange;

  // Iterate until we stop promoting from this SCC.
  do {
    LocalChange = false;

    FunctionAnalysisManager &FAM =
        AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

    bool IsRecursive = C.size() > 1;
    for (LazyCallGraph::Node &N : C) {
      Function &OldF = N.getFunction();

      // FIXME: This lambda must only be used with this function. We should
      // skip the lambda and just get the AA results directly.
      auto AARGetter = [&](Function &F) -> AAResults & {
        assert(&F == &OldF && "Called with an unexpected function!");
        return FAM.getResult<AAManager>(F);
      };

      const TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(OldF);
      Function *NewF = promoteArguments(&OldF, AARGetter, MaxElements, None,
                                        TTI, IsRecursive);
      if (!NewF)
        continue;
      LocalChange = true;

      // Directly substitute the functions in the call graph. Note that this
      // requires the old function to be completely dead and completely
      // replaced by the new function. It does no call graph updates, it merely
      // swaps out the particular function mapped to a particular node in the
      // graph.
      C.getOuterRefSCC().replaceNodeFunction(N, *NewF);
      FAM.clear(OldF, OldF.getName());
      OldF.eraseFromParent();

      PreservedAnalyses FuncPA;
      FuncPA.preserveSet<CFGAnalyses>();
      for (auto *U : NewF->users()) {
        auto *UserF = cast<CallBase>(U)->getFunction();
        FAM.invalidate(*UserF, FuncPA);
      }
    }

    Changed |= LocalChange;
  } while (LocalChange);

  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  // We've cleared out analyses for deleted functions.
  PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
  // We've manually invalidated analyses for functions we've modified.
  PA.preserveSet<AllAnalysesOn<Function>>();
  return PA;
}

namespace {

/// ArgPromotion - The 'by reference' to 'by value' argument promotion pass.
struct ArgPromotion : public CallGraphSCCPass {
  // Pass identification, replacement for typeid
  static char ID;

  explicit ArgPromotion(unsigned MaxElements = 3)
      : CallGraphSCCPass(ID), MaxElements(MaxElements) {
    initializeArgPromotionPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    getAAResultsAnalysisUsage(AU);
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  bool runOnSCC(CallGraphSCC &SCC) override;

private:
  using llvm::Pass::doInitialization;

  bool doInitialization(CallGraph &CG) override;

  /// The maximum number of elements to expand, or 0 for unlimited.
  unsigned MaxElements;
};

} // end anonymous namespace

char ArgPromotion::ID = 0;

INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion",
                      "Promote 'by reference' arguments to scalars", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ArgPromotion, "argpromotion",
                    "Promote 'by reference' arguments to scalars", false, false)

Pass *llvm::createArgumentPromotionPass(unsigned MaxElements) {
  return new ArgPromotion(MaxElements);
}

bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
  if (skipSCC(SCC))
    return false;

  // Get the callgraph information that we need to update to reflect our
  // changes.
  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();

  LegacyAARGetter AARGetter(*this);

  bool Changed = false, LocalChange;

  // Iterate until we stop promoting from this SCC.
  do {
    LocalChange = false;
    // Attempt to promote arguments from all functions in this SCC.
    bool IsRecursive = SCC.size() > 1;
    for (CallGraphNode *OldNode : SCC) {
      Function *OldF = OldNode->getFunction();
      if (!OldF)
        continue;

      auto ReplaceCallSite = [&](CallBase &OldCS, CallBase &NewCS) {
        Function *Caller = OldCS.getParent()->getParent();
        CallGraphNode *NewCalleeNode =
            CG.getOrInsertFunction(NewCS.getCalledFunction());
        CallGraphNode *CallerNode = CG[Caller];
        CallerNode->replaceCallEdge(OldCS, NewCS, NewCalleeNode);
      };

      const TargetTransformInfo &TTI =
          getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*OldF);
      if (Function *NewF =
              promoteArguments(OldF, AARGetter, MaxElements, {ReplaceCallSite},
                               TTI, IsRecursive)) {
        LocalChange = true;

        // Update the call graph for the newly promoted function.
        CallGraphNode *NewNode = CG.getOrInsertFunction(NewF);
        NewNode->stealCalledFunctionsFrom(OldNode);
        if (OldNode->getNumReferences() == 0)
          delete CG.removeFunctionFromModule(OldNode);
        else
          OldF->setLinkage(Function::ExternalLinkage);

        // And update the SCC we're iterating as well.
        SCC.ReplaceNode(OldNode, NewNode);
      }
    }
    // Remember that we changed something.
    Changed |= LocalChange;
  } while (LocalChange);

  return Changed;
}

bool ArgPromotion::doInitialization(CallGraph &CG) {
  return CallGraphSCCPass::doInitialization(CG);
}