1 //===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass promotes "by reference" arguments to be "by value" arguments.  In
10 // practice, this means looking for internal functions that have pointer
11 // arguments.  If it can prove, through the use of alias analysis, that an
12 // argument is *only* loaded, then it can pass the value into the function
13 // instead of the address of the value.  This can cause recursive simplification
14 // of code and lead to the elimination of allocas (especially in C++ template
15 // code like the STL).
16 //
17 // This pass also handles aggregate arguments that are passed into a function,
18 // scalarizing them if the elements of the aggregate are only loaded.  Note that
19 // by default it refuses to scalarize aggregates which would require passing in
20 // more than three operands to the function, because passing thousands of
21 // operands for a large array or structure is unprofitable! This limit can be
22 // configured or disabled, however.
23 //
24 // Note that this transformation could also be done for arguments that are only
25 // stored to (returning the value instead), but does not currently.  This case
26 // would be best handled when and if LLVM begins supporting multiple return
27 // values from functions.
28 //
29 //===----------------------------------------------------------------------===//
30 
31 #include "llvm/Transforms/IPO/ArgumentPromotion.h"
32 #include "llvm/ADT/DepthFirstIterator.h"
33 #include "llvm/ADT/None.h"
34 #include "llvm/ADT/Optional.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/ScopeExit.h"
37 #include "llvm/ADT/SmallPtrSet.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/Statistic.h"
40 #include "llvm/ADT/Twine.h"
41 #include "llvm/Analysis/AssumptionCache.h"
42 #include "llvm/Analysis/BasicAliasAnalysis.h"
43 #include "llvm/Analysis/CGSCCPassManager.h"
44 #include "llvm/Analysis/CallGraph.h"
45 #include "llvm/Analysis/CallGraphSCCPass.h"
46 #include "llvm/Analysis/LazyCallGraph.h"
47 #include "llvm/Analysis/Loads.h"
48 #include "llvm/Analysis/MemoryLocation.h"
49 #include "llvm/Analysis/ValueTracking.h"
50 #include "llvm/Analysis/TargetLibraryInfo.h"
51 #include "llvm/Analysis/TargetTransformInfo.h"
52 #include "llvm/IR/Argument.h"
53 #include "llvm/IR/Attributes.h"
54 #include "llvm/IR/BasicBlock.h"
55 #include "llvm/IR/CFG.h"
56 #include "llvm/IR/Constants.h"
57 #include "llvm/IR/DataLayout.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/IRBuilder.h"
61 #include "llvm/IR/InstrTypes.h"
62 #include "llvm/IR/Instruction.h"
63 #include "llvm/IR/Instructions.h"
64 #include "llvm/IR/Metadata.h"
65 #include "llvm/IR/Module.h"
66 #include "llvm/IR/NoFolder.h"
67 #include "llvm/IR/PassManager.h"
68 #include "llvm/IR/Type.h"
69 #include "llvm/IR/Use.h"
70 #include "llvm/IR/User.h"
71 #include "llvm/IR/Value.h"
72 #include "llvm/InitializePasses.h"
73 #include "llvm/Pass.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/Debug.h"
76 #include "llvm/Support/FormatVariadic.h"
77 #include "llvm/Support/raw_ostream.h"
78 #include "llvm/Transforms/IPO.h"
79 #include <algorithm>
80 #include <cassert>
81 #include <cstdint>
82 #include <functional>
83 #include <iterator>
84 #include <map>
85 #include <set>
86 #include <utility>
87 #include <vector>
88 
89 using namespace llvm;
90 
91 #define DEBUG_TYPE "argpromotion"
92 
93 STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted");
94 STATISTIC(NumByValArgsPromoted, "Number of byval arguments promoted");
95 STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated");
96 
97 struct ArgPart {
98   Type *Ty;
99   Align Alignment;
100   /// A representative guaranteed-executed load instruction for use by
101   /// metadata transfer.
102   LoadInst *MustExecLoad;
103 };
104 using OffsetAndArgPart = std::pair<int64_t, ArgPart>;
105 
/// Create a GEP computing \p Ptr + \p Offset (in bytes), with the result
/// pointer typed to point at \p ResElemTy.
///
/// For typed (non-opaque) pointers this first tries to express the offset as
/// a chain of structured GEP indices into the original pointee type so the
/// emitted IR stays readable; otherwise it falls back to an i8-based GEP at
/// the raw byte offset followed by a bitcast.
static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL,
                            Value *Ptr, Type *ResElemTy, int64_t Offset) {
  // For non-opaque pointers, try create a "nice" GEP if possible, otherwise
  // fall back to an i8 GEP to a specific offset.
  unsigned AddrSpace = Ptr->getType()->getPointerAddressSpace();
  APInt OrigOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset);
  if (!Ptr->getType()->isOpaquePointerTy()) {
    Type *OrigElemTy = Ptr->getType()->getNonOpaquePointerElementType();
    // Zero offset and matching pointee type: the pointer is usable as-is.
    if (OrigOffset == 0 && OrigElemTy == ResElemTy)
      return Ptr;

    if (OrigElemTy->isSized()) {
      // Ask DataLayout for natural indices covering the offset. TmpOffset is
      // reduced to the residue the structured indices could not express; only
      // a residue of zero means the offset was fully representable.
      APInt TmpOffset = OrigOffset;
      Type *TmpTy = OrigElemTy;
      SmallVector<APInt> IntIndices =
          DL.getGEPIndicesForOffset(TmpTy, TmpOffset);
      if (TmpOffset == 0) {
        // Try to add trailing zero indices to reach the right type.
        while (TmpTy != ResElemTy) {
          Type *NextTy = GetElementPtrInst::getTypeAtIndex(TmpTy, (uint64_t)0);
          if (!NextTy)
            break;

          // Struct field indices must be i32; array/vector indices use the
          // pointer index width.
          IntIndices.push_back(APInt::getZero(
              isa<StructType>(TmpTy) ? 32 : OrigOffset.getBitWidth()));
          TmpTy = NextTy;
        }

        SmallVector<Value *> Indices;
        for (const APInt &Index : IntIndices)
          Indices.push_back(IRB.getInt(Index));

        // Emit the structured GEP when it performs a non-trivial offset, or
        // when the zero-index chain landed exactly on the desired type.
        if (OrigOffset != 0 || TmpTy == ResElemTy) {
          Ptr = IRB.CreateGEP(OrigElemTy, Ptr, Indices);
          return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace));
        }
      }
    }
  }

  // Fallback path: raw byte arithmetic through an i8* GEP, then cast to the
  // requested result pointer type.
  if (OrigOffset != 0) {
    Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy(AddrSpace));
    Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(OrigOffset));
  }
  return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace));
}
152 
/// DoPromotion - This method actually performs the promotion of the specified
/// arguments, and returns the new function.  At this point, we know that it's
/// safe to do so.
///
/// The rewrite proceeds in three stages:
///   1. Build the new prototype: promoted parts become scalar parameters and
///      byval structs become one parameter per element.
///   2. Rewrite every call site, emitting the GEPs/loads in the caller.
///   3. Splice the old body into the new function and redirect argument uses.
///
/// \p ReplaceCallSite, if present, is invoked with each (old, new) call pair
/// so external bookkeeping (e.g. the legacy pass manager's call graph) can be
/// updated.
static Function *doPromotion(
    Function *F,
    const DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> &ArgsToPromote,
    SmallPtrSetImpl<Argument *> &ByValArgsToTransform,
    Optional<function_ref<void(CallBase &OldCS, CallBase &NewCS)>>
        ReplaceCallSite) {
  // Start by computing a new prototype for the function, which is the same as
  // the old function, but has modified arguments.
  FunctionType *FTy = F->getFunctionType();
  std::vector<Type *> Params;

  // Attribute - Keep track of the parameter attributes for the arguments
  // that we are *not* promoting. For the ones that we do promote, the parameter
  // attributes are lost
  SmallVector<AttributeSet, 8> ArgAttrVec;
  AttributeList PAL = F->getAttributes();

  // First, determine the new argument list
  unsigned ArgNo = 0;
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++ArgNo) {
    if (ByValArgsToTransform.count(&*I)) {
      // Simple byval argument? Just add all the struct element types.
      Type *AgTy = I->getParamByValType();
      StructType *STy = cast<StructType>(AgTy);
      llvm::append_range(Params, STy->elements());
      // The scalarized elements carry no parameter attributes.
      ArgAttrVec.insert(ArgAttrVec.end(), STy->getNumElements(),
                        AttributeSet());
      ++NumByValArgsPromoted;
    } else if (!ArgsToPromote.count(&*I)) {
      // Unchanged argument
      Params.push_back(I->getType());
      ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo));
    } else if (I->use_empty()) {
      // Dead argument (which are always marked as promotable)
      ++NumArgumentsDead;
    } else {
      // Promoted argument: one scalar parameter per (offset, type) part.
      const auto &ArgParts = ArgsToPromote.find(&*I)->second;
      for (const auto &Pair : ArgParts) {
        Params.push_back(Pair.second.Ty);
        ArgAttrVec.push_back(AttributeSet());
      }
      ++NumArgumentsPromoted;
    }
  }

  Type *RetTy = FTy->getReturnType();

  // Construct the new function type using the new arguments.
  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());

  // Create the new function body and insert it into the module.
  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(),
                                  F->getName());
  NF->copyAttributesFrom(F);
  NF->copyMetadata(F, 0);

  // The new function will have the !dbg metadata copied from the original
  // function. The original function may not be deleted, and dbg metadata need
  // to be unique so we need to drop it.
  F->setSubprogram(nullptr);

  LLVM_DEBUG(dbgs() << "ARG PROMOTION:  Promoting to:" << *NF << "\n"
                    << "From: " << *F);

  // Recompute the parameter attributes list based on the new arguments for
  // the function.
  NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(),
                                       PAL.getRetAttrs(), ArgAttrVec));
  ArgAttrVec.clear();

  F->getParent()->getFunctionList().insert(F->getIterator(), NF);
  NF->takeName(F);

  // Loop over all of the callers of the function, transforming the call sites
  // to pass in the loaded pointers.
  //
  SmallVector<Value *, 16> Args;
  const DataLayout &DL = F->getParent()->getDataLayout();
  // Each iteration erases one call site, so F's use list shrinks to empty.
  while (!F->use_empty()) {
    CallBase &CB = cast<CallBase>(*F->user_back());
    assert(CB.getCalledFunction() == F);
    const AttributeList &CallPAL = CB.getAttributes();
    // NoFolder keeps the emitted GEPs/loads as real instructions.
    IRBuilder<NoFolder> IRB(&CB);

    // Loop over the operands, inserting GEP and loads in the caller as
    // appropriate.
    auto AI = CB.arg_begin();
    ArgNo = 0;
    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
         ++I, ++AI, ++ArgNo)
      if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) {
        Args.push_back(*AI); // Unmodified argument
        ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
      } else if (ByValArgsToTransform.count(&*I)) {
        // Emit a GEP and load for each element of the struct.
        Type *AgTy = I->getParamByValType();
        StructType *STy = cast<StructType>(AgTy);
        Value *Idxs[2] = {
            ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr};
        const StructLayout *SL = DL.getStructLayout(STy);
        Align StructAlign = *I->getParamAlign();
        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
          Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
          auto *Idx =
              IRB.CreateGEP(STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i));
          // TODO: Tell AA about the new values?
          Align Alignment =
              commonAlignment(StructAlign, SL->getElementOffset(i));
          Args.push_back(IRB.CreateAlignedLoad(
              STy->getElementType(i), Idx, Alignment, Idx->getName() + ".val"));
          ArgAttrVec.push_back(AttributeSet());
        }
      } else if (!I->use_empty()) {
        // Promoted argument: load each part at its recorded byte offset.
        Value *V = *AI;
        const auto &ArgParts = ArgsToPromote.find(&*I)->second;
        for (const auto &Pair : ArgParts) {
          LoadInst *LI = IRB.CreateAlignedLoad(
              Pair.second.Ty,
              createByteGEP(IRB, DL, V, Pair.second.Ty, Pair.first),
              Pair.second.Alignment, V->getName() + ".val");
          // Carry AA and value metadata from a guaranteed-executed load in
          // the callee over to the new caller-side load, when one exists.
          if (Pair.second.MustExecLoad) {
            LI->setAAMetadata(Pair.second.MustExecLoad->getAAMetadata());
            LI->copyMetadata(*Pair.second.MustExecLoad,
                             {LLVMContext::MD_range, LLVMContext::MD_nonnull,
                              LLVMContext::MD_dereferenceable,
                              LLVMContext::MD_dereferenceable_or_null,
                              LLVMContext::MD_align, LLVMContext::MD_noundef});
          }
          Args.push_back(LI);
          ArgAttrVec.push_back(AttributeSet());
        }
      }
      // Dead promoted arguments (use_empty) simply pass nothing here.

    // Push any varargs arguments on the list.
    for (; AI != CB.arg_end(); ++AI, ++ArgNo) {
      Args.push_back(*AI);
      ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
    }

    SmallVector<OperandBundleDef, 1> OpBundles;
    CB.getOperandBundlesAsDefs(OpBundles);

    // Recreate the call/invoke against the new function.
    CallBase *NewCS = nullptr;
    if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) {
      NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
                                 Args, OpBundles, "", &CB);
    } else {
      auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", &CB);
      NewCall->setTailCallKind(cast<CallInst>(&CB)->getTailCallKind());
      NewCS = NewCall;
    }
    NewCS->setCallingConv(CB.getCallingConv());
    NewCS->setAttributes(AttributeList::get(F->getContext(),
                                            CallPAL.getFnAttrs(),
                                            CallPAL.getRetAttrs(), ArgAttrVec));
    NewCS->copyMetadata(CB, {LLVMContext::MD_prof, LLVMContext::MD_dbg});
    Args.clear();
    ArgAttrVec.clear();

    // Update the callgraph to know that the callsite has been transformed.
    if (ReplaceCallSite)
      (*ReplaceCallSite)(CB, *NewCS);

    if (!CB.use_empty()) {
      CB.replaceAllUsesWith(NewCS);
      NewCS->takeName(&CB);
    }

    // Finally, remove the old call from the program, reducing the use-count of
    // F.
    CB.eraseFromParent();
  }

  // Since we have now created the new function, splice the body of the old
  // function right into the new function, leaving the old rotting hulk of the
  // function empty.
  NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());

  // Loop over the argument list, transferring uses of the old arguments over to
  // the new arguments, also transferring over the names as well.
  Function::arg_iterator I2 = NF->arg_begin();
  for (Argument &Arg : F->args()) {
    if (!ArgsToPromote.count(&Arg) && !ByValArgsToTransform.count(&Arg)) {
      // If this is an unmodified argument, move the name and users over to the
      // new version.
      Arg.replaceAllUsesWith(&*I2);
      I2->takeName(&Arg);
      ++I2;
      continue;
    }

    if (ByValArgsToTransform.count(&Arg)) {
      // In the callee, we create an alloca, and store each of the new incoming
      // arguments into the alloca.
      Instruction *InsertPt = &NF->begin()->front();

      // Just add all the struct element types.
      Type *AgTy = Arg.getParamByValType();
      Align StructAlign = *Arg.getParamAlign();
      Value *TheAlloca = new AllocaInst(AgTy, DL.getAllocaAddrSpace(), nullptr,
                                        StructAlign, "", InsertPt);
      StructType *STy = cast<StructType>(AgTy);
      Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(F->getContext()), 0),
                        nullptr};
      const StructLayout *SL = DL.getStructLayout(STy);

      // Store each incoming scalar parameter into its slot in the alloca;
      // I2 advances over the scalarized parameters as we go.
      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
        Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
        Value *Idx = GetElementPtrInst::Create(
            AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(i),
            InsertPt);
        I2->setName(Arg.getName() + "." + Twine(i));
        Align Alignment = commonAlignment(StructAlign, SL->getElementOffset(i));
        new StoreInst(&*I2++, Idx, false, Alignment, InsertPt);
      }

      // Anything that used the arg should now use the alloca.
      Arg.replaceAllUsesWith(TheAlloca);
      TheAlloca->takeName(&Arg);
      continue;
    }

    // There potentially are metadata uses for things like llvm.dbg.value.
    // Replace them with undef, after handling the other regular uses.
    auto RauwUndefMetadata = make_scope_exit(
        [&]() { Arg.replaceAllUsesWith(UndefValue::get(Arg.getType())); });

    if (Arg.use_empty())
      continue;

    // Map each promoted byte offset to the new argument carrying that part.
    SmallDenseMap<int64_t, Argument *> OffsetToArg;
    for (const auto &Pair : ArgsToPromote.find(&Arg)->second) {
      Argument &NewArg = *I2++;
      NewArg.setName(Arg.getName() + "." + Twine(Pair.first) + ".val");
      OffsetToArg.insert({Pair.first, &NewArg});
    }

    // Otherwise, if we promoted this argument, then all users are load
    // instructions (with possible casts and GEPs in between).

    SmallVector<Value *, 16> Worklist;
    SmallVector<Instruction *, 16> DeadInsts;
    append_range(Worklist, Arg.users());
    while (!Worklist.empty()) {
      Value *V = Worklist.pop_back_val();
      if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V)) {
        // Pure address computation: becomes dead once its loads are rewritten.
        DeadInsts.push_back(cast<Instruction>(V));
        append_range(Worklist, V->users());
        continue;
      }

      if (auto *LI = dyn_cast<LoadInst>(V)) {
        // Recover the constant byte offset and substitute the matching new
        // scalar argument for the load's result.
        Value *Ptr = LI->getPointerOperand();
        APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
        Ptr =
            Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                   /* AllowNonInbounds */ true);
        assert(Ptr == &Arg && "Not constant offset from arg?");
        LI->replaceAllUsesWith(OffsetToArg[Offset.getSExtValue()]);
        DeadInsts.push_back(LI);
        continue;
      }

      llvm_unreachable("Unexpected user");
    }

    // Delete the now-dead casts, GEPs, and loads.
    for (Instruction *I : DeadInsts) {
      I->replaceAllUsesWith(UndefValue::get(I->getType()));
      I->eraseFromParent();
    }
  }

  return NF;
}
431 
432 /// Return true if we can prove that all callees pass in a valid pointer for the
433 /// specified function argument.
434 static bool allCallersPassValidPointerForArgument(Argument *Arg,
435                                                   Align NeededAlign,
436                                                   uint64_t NeededDerefBytes) {
437   Function *Callee = Arg->getParent();
438   const DataLayout &DL = Callee->getParent()->getDataLayout();
439   APInt Bytes(64, NeededDerefBytes);
440 
441   // Check if the argument itself is marked dereferenceable and aligned.
442   if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL))
443     return true;
444 
445   // Look at all call sites of the function.  At this point we know we only have
446   // direct callees.
447   return all_of(Callee->users(), [&](User *U) {
448     CallBase &CB = cast<CallBase>(*U);
449     return isDereferenceableAndAlignedPointer(
450         CB.getArgOperand(Arg->getArgNo()), NeededAlign, Bytes, DL);
451   });
452 }
453 
454 /// Determine that this argument is safe to promote, and find the argument
455 /// parts it can be promoted into.
456 static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
457                          unsigned MaxElements, bool IsRecursive,
458                          SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
459   // Quick exit for unused arguments
460   if (Arg->use_empty())
461     return true;
462 
463   // We can only promote this argument if all of the uses are loads at known
464   // offsets.
465   //
466   // Promoting the argument causes it to be loaded in the caller
467   // unconditionally. This is only safe if we can prove that either the load
468   // would have happened in the callee anyway (ie, there is a load in the entry
469   // block) or the pointer passed in at every call site is guaranteed to be
470   // valid.
471   // In the former case, invalid loads can happen, but would have happened
472   // anyway, in the latter case, invalid loads won't happen. This prevents us
473   // from introducing an invalid load that wouldn't have happened in the
474   // original code.
475 
476   SmallDenseMap<int64_t, ArgPart, 4> ArgParts;
477   Align NeededAlign(1);
478   uint64_t NeededDerefBytes = 0;
479 
480   // Returns None if this load is not based on the argument. Return true if
481   // we can promote the load, false otherwise.
482   auto HandleLoad = [&](LoadInst *LI,
483                         bool GuaranteedToExecute) -> Optional<bool> {
484     // Don't promote volatile or atomic loads.
485     if (!LI->isSimple())
486       return false;
487 
488     Value *Ptr = LI->getPointerOperand();
489     APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
490     Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
491                                                  /* AllowNonInbounds */ true);
492     if (Ptr != Arg)
493       return None;
494 
495     if (Offset.getSignificantBits() >= 64)
496       return false;
497 
498     Type *Ty = LI->getType();
499     TypeSize Size = DL.getTypeStoreSize(Ty);
500     // Don't try to promote scalable types.
501     if (Size.isScalable())
502       return false;
503 
504     // If this is a recursive function and one of the types is a pointer,
505     // then promoting it might lead to recursive promotion.
506     if (IsRecursive && Ty->isPointerTy())
507       return false;
508 
509     int64_t Off = Offset.getSExtValue();
510     auto Pair = ArgParts.try_emplace(
511         Off, ArgPart{Ty, LI->getAlign(), GuaranteedToExecute ? LI : nullptr});
512     ArgPart &Part = Pair.first->second;
513     bool OffsetNotSeenBefore = Pair.second;
514 
515     // We limit promotion to only promoting up to a fixed number of elements of
516     // the aggregate.
517     if (MaxElements > 0 && ArgParts.size() >= MaxElements) {
518       LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
519                         << "more than " << MaxElements << " parts\n");
520       return false;
521     }
522 
523     // For now, we only support loading one specific type at a given offset.
524     if (Part.Ty != Ty) {
525       LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
526                         << "loaded via both " << *Part.Ty << " and " << *Ty
527                         << " at offset " << Off << "\n");
528       return false;
529     }
530 
531     // If this load is not guaranteed to execute and we haven't seen a load at
532     // this offset before (or it had lower alignment), then we need to remember
533     // that requirement.
534     // Note that skipping loads of previously seen offsets is only correct
535     // because we only allow a single type for a given offset, which also means
536     // that the number of accessed bytes will be the same.
537     if (!GuaranteedToExecute &&
538         (OffsetNotSeenBefore || Part.Alignment < LI->getAlign())) {
539       // We won't be able to prove dereferenceability for negative offsets.
540       if (Off < 0)
541         return false;
542 
543       // If the offset is not aligned, an aligned base pointer won't help.
544       if (!isAligned(LI->getAlign(), Off))
545         return false;
546 
547       NeededDerefBytes = std::max(NeededDerefBytes, Off + Size.getFixedValue());
548       NeededAlign = std::max(NeededAlign, LI->getAlign());
549     }
550 
551     Part.Alignment = std::max(Part.Alignment, LI->getAlign());
552     return true;
553   };
554 
555   // Look for loads that are guaranteed to execute on entry.
556   for (Instruction &I : Arg->getParent()->getEntryBlock()) {
557     if (LoadInst *LI = dyn_cast<LoadInst>(&I))
558       if (Optional<bool> Res = HandleLoad(LI, /* GuaranteedToExecute */ true))
559         if (!*Res)
560           return false;
561 
562     if (!isGuaranteedToTransferExecutionToSuccessor(&I))
563       break;
564   }
565 
566   // Now look at all loads of the argument. Remember the load instructions
567   // for the aliasing check below.
568   SmallVector<Value *, 16> Worklist;
569   SmallPtrSet<Value *, 16> Visited;
570   SmallVector<LoadInst *, 16> Loads;
571   auto AppendUsers = [&](Value *V) {
572     for (User *U : V->users())
573       if (Visited.insert(U).second)
574         Worklist.push_back(U);
575   };
576   AppendUsers(Arg);
577   while (!Worklist.empty()) {
578     Value *V = Worklist.pop_back_val();
579     if (isa<BitCastInst>(V)) {
580       AppendUsers(V);
581       continue;
582     }
583 
584     if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
585       if (!GEP->hasAllConstantIndices())
586         return false;
587       AppendUsers(V);
588       continue;
589     }
590 
591     if (auto *LI = dyn_cast<LoadInst>(V)) {
592       if (!*HandleLoad(LI, /* GuaranteedToExecute */ false))
593         return false;
594       Loads.push_back(LI);
595       continue;
596     }
597 
598     // Unknown user.
599     LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
600                       << "unknown user " << *V << "\n");
601     return false;
602   }
603 
604   if (NeededDerefBytes || NeededAlign > 1) {
605     // Try to prove a required deref / aligned requirement.
606     if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
607                                                NeededDerefBytes)) {
608       LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
609                         << "not dereferenceable or aligned\n");
610       return false;
611     }
612   }
613 
614   if (ArgParts.empty())
615     return true; // No users, this is a dead argument.
616 
617   // Sort parts by offset.
618   append_range(ArgPartsVec, ArgParts);
619   sort(ArgPartsVec,
620        [](const auto &A, const auto &B) { return A.first < B.first; });
621 
622   // Make sure the parts are non-overlapping.
623   // TODO: As we're doing pure load promotion here, overlap should be fine from
624   // a correctness perspective. Profitability is less obvious though.
625   int64_t Offset = ArgPartsVec[0].first;
626   for (const auto &Pair : ArgPartsVec) {
627     if (Pair.first < Offset)
628       return false; // Overlap with previous part.
629 
630     Offset = Pair.first + DL.getTypeStoreSize(Pair.second.Ty);
631   }
632 
633   // Okay, now we know that the argument is only used by load instructions and
634   // it is safe to unconditionally perform all of them. Use alias analysis to
635   // check to see if the pointer is guaranteed to not be modified from entry of
636   // the function to each of the load instructions.
637 
638   // Because there could be several/many load instructions, remember which
639   // blocks we know to be transparent to the load.
640   df_iterator_default_set<BasicBlock *, 16> TranspBlocks;
641 
642   for (LoadInst *Load : Loads) {
643     // Check to see if the load is invalidated from the start of the block to
644     // the load itself.
645     BasicBlock *BB = Load->getParent();
646 
647     MemoryLocation Loc = MemoryLocation::get(Load);
648     if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, ModRefInfo::Mod))
649       return false; // Pointer is invalidated!
650 
651     // Now check every path from the entry block to the load for transparency.
652     // To do this, we perform a depth first search on the inverse CFG from the
653     // loading block.
654     for (BasicBlock *P : predecessors(BB)) {
655       for (BasicBlock *TranspBB : inverse_depth_first_ext(P, TranspBlocks))
656         if (AAR.canBasicBlockModify(*TranspBB, Loc))
657           return false;
658     }
659   }
660 
661   // If the path from the entry of the function to each load is free of
662   // instructions that potentially invalidate the load, we can make the
663   // transformation!
664   return true;
665 }
666 
667 bool ArgumentPromotionPass::isDenselyPacked(Type *type, const DataLayout &DL) {
668   // There is no size information, so be conservative.
669   if (!type->isSized())
670     return false;
671 
672   // If the alloc size is not equal to the storage size, then there are padding
673   // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128.
674   if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type))
675     return false;
676 
677   // FIXME: This isn't the right way to check for padding in vectors with
678   // non-byte-size elements.
679   if (VectorType *seqTy = dyn_cast<VectorType>(type))
680     return isDenselyPacked(seqTy->getElementType(), DL);
681 
682   // For array types, check for padding within members.
683   if (ArrayType *seqTy = dyn_cast<ArrayType>(type))
684     return isDenselyPacked(seqTy->getElementType(), DL);
685 
686   if (!isa<StructType>(type))
687     return true;
688 
689   // Check for padding within and between elements of a struct.
690   StructType *StructTy = cast<StructType>(type);
691   const StructLayout *Layout = DL.getStructLayout(StructTy);
692   uint64_t StartPos = 0;
693   for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) {
694     Type *ElTy = StructTy->getElementType(i);
695     if (!isDenselyPacked(ElTy, DL))
696       return false;
697     if (StartPos != Layout->getElementOffsetInBits(i))
698       return false;
699     StartPos += DL.getTypeAllocSizeInBits(ElTy);
700   }
701 
702   return true;
703 }
704 
705 /// Checks if the padding bytes of an argument could be accessed.
706 static bool canPaddingBeAccessed(Argument *arg) {
707   assert(arg->hasByValAttr());
708 
709   // Track all the pointers to the argument to make sure they are not captured.
710   SmallPtrSet<Value *, 16> PtrValues;
711   PtrValues.insert(arg);
712 
713   // Track all of the stores.
714   SmallVector<StoreInst *, 16> Stores;
715 
716   // Scan through the uses recursively to make sure the pointer is always used
717   // sanely.
718   SmallVector<Value *, 16> WorkList(arg->users());
719   while (!WorkList.empty()) {
720     Value *V = WorkList.pop_back_val();
721     if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) {
722       if (PtrValues.insert(V).second)
723         llvm::append_range(WorkList, V->users());
724     } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) {
725       Stores.push_back(Store);
726     } else if (!isa<LoadInst>(V)) {
727       return true;
728     }
729   }
730 
731   // Check to make sure the pointers aren't captured
732   for (StoreInst *Store : Stores)
733     if (PtrValues.count(Store->getValueOperand()))
734       return true;
735 
736   return false;
737 }
738 
739 /// Check if callers and callee agree on how promoted arguments would be
740 /// passed.
741 static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
742                                   const TargetTransformInfo &TTI) {
743   return all_of(F.uses(), [&](const Use &U) {
744     CallBase *CB = dyn_cast<CallBase>(U.getUser());
745     if (!CB)
746       return false;
747 
748     const Function *Caller = CB->getCaller();
749     const Function *Callee = CB->getCalledFunction();
750     return TTI.areTypesABICompatible(Caller, Callee, Types);
751   });
752 }
753 
754 /// PromoteArguments - This method checks the specified function to see if there
755 /// are any promotable arguments and if it is safe to promote the function (for
756 /// example, all callers are direct).  If safe to promote some arguments, it
757 /// calls the DoPromotion method.
758 static Function *
759 promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter,
760                  unsigned MaxElements,
761                  Optional<function_ref<void(CallBase &OldCS, CallBase &NewCS)>>
762                      ReplaceCallSite,
763                  const TargetTransformInfo &TTI, bool IsRecursive) {
764   // Don't perform argument promotion for naked functions; otherwise we can end
765   // up removing parameters that are seemingly 'not used' as they are referred
766   // to in the assembly.
767   if(F->hasFnAttribute(Attribute::Naked))
768     return nullptr;
769 
770   // Make sure that it is local to this module.
771   if (!F->hasLocalLinkage())
772     return nullptr;
773 
774   // Don't promote arguments for variadic functions. Adding, removing, or
775   // changing non-pack parameters can change the classification of pack
776   // parameters. Frontends encode that classification at the call site in the
777   // IR, while in the callee the classification is determined dynamically based
778   // on the number of registers consumed so far.
779   if (F->isVarArg())
780     return nullptr;
781 
782   // Don't transform functions that receive inallocas, as the transformation may
783   // not be safe depending on calling convention.
784   if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca))
785     return nullptr;
786 
787   // First check: see if there are any pointer arguments!  If not, quick exit.
788   SmallVector<Argument *, 16> PointerArgs;
789   for (Argument &I : F->args())
790     if (I.getType()->isPointerTy())
791       PointerArgs.push_back(&I);
792   if (PointerArgs.empty())
793     return nullptr;
794 
795   // Second check: make sure that all callers are direct callers.  We can't
796   // transform functions that have indirect callers.  Also see if the function
797   // is self-recursive.
798   for (Use &U : F->uses()) {
799     CallBase *CB = dyn_cast<CallBase>(U.getUser());
800     // Must be a direct call.
801     if (CB == nullptr || !CB->isCallee(&U))
802       return nullptr;
803 
804     // Can't change signature of musttail callee
805     if (CB->isMustTailCall())
806       return nullptr;
807 
808     if (CB->getParent()->getParent() == F)
809       IsRecursive = true;
810   }
811 
812   // Can't change signature of musttail caller
813   // FIXME: Support promoting whole chain of musttail functions
814   for (BasicBlock &BB : *F)
815     if (BB.getTerminatingMustTailCall())
816       return nullptr;
817 
818   const DataLayout &DL = F->getParent()->getDataLayout();
819 
820   AAResults &AAR = AARGetter(*F);
821 
822   // Check to see which arguments are promotable.  If an argument is promotable,
823   // add it to ArgsToPromote.
824   DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> ArgsToPromote;
825   SmallPtrSet<Argument *, 8> ByValArgsToTransform;
826   for (Argument *PtrArg : PointerArgs) {
827     // Replace sret attribute with noalias. This reduces register pressure by
828     // avoiding a register copy.
829     if (PtrArg->hasStructRetAttr()) {
830       unsigned ArgNo = PtrArg->getArgNo();
831       F->removeParamAttr(ArgNo, Attribute::StructRet);
832       F->addParamAttr(ArgNo, Attribute::NoAlias);
833       for (Use &U : F->uses()) {
834         CallBase &CB = cast<CallBase>(*U.getUser());
835         CB.removeParamAttr(ArgNo, Attribute::StructRet);
836         CB.addParamAttr(ArgNo, Attribute::NoAlias);
837       }
838     }
839 
840     // If this is a byval argument, and if the aggregate type is small, just
841     // pass the elements, which is always safe, if the passed value is densely
842     // packed or if we can prove the padding bytes are never accessed.
843     //
844     // Only handle arguments with specified alignment; if it's unspecified, the
845     // actual alignment of the argument is target-specific.
846     Type *ByValTy = PtrArg->getParamByValType();
847     bool isSafeToPromote =
848         ByValTy && PtrArg->getParamAlign() &&
849         (ArgumentPromotionPass::isDenselyPacked(ByValTy, DL) ||
850          !canPaddingBeAccessed(PtrArg));
851     if (isSafeToPromote) {
852       if (StructType *STy = dyn_cast<StructType>(ByValTy)) {
853         if (MaxElements > 0 && STy->getNumElements() > MaxElements) {
854           LLVM_DEBUG(dbgs() << "argpromotion disable promoting argument '"
855                             << PtrArg->getName()
856                             << "' because it would require adding more"
857                             << " than " << MaxElements
858                             << " arguments to the function.\n");
859           continue;
860         }
861 
862         SmallVector<Type *, 4> Types;
863         append_range(Types, STy->elements());
864 
865         // If all the elements are single-value types, we can promote it.
866         bool AllSimple =
867             all_of(Types, [](Type *Ty) { return Ty->isSingleValueType(); });
868 
869         // Safe to transform, don't even bother trying to "promote" it.
870         // Passing the elements as a scalar will allow sroa to hack on
871         // the new alloca we introduce.
872         if (AllSimple && areTypesABICompatible(Types, *F, TTI)) {
873           ByValArgsToTransform.insert(PtrArg);
874           continue;
875         }
876       }
877     }
878 
879     // Otherwise, see if we can promote the pointer to its value.
880     SmallVector<OffsetAndArgPart, 4> ArgParts;
881     if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
882       SmallVector<Type *, 4> Types;
883       for (const auto &Pair : ArgParts)
884         Types.push_back(Pair.second.Ty);
885 
886       if (areTypesABICompatible(Types, *F, TTI))
887         ArgsToPromote.insert({PtrArg, std::move(ArgParts)});
888     }
889   }
890 
891   // No promotable pointer arguments.
892   if (ArgsToPromote.empty() && ByValArgsToTransform.empty())
893     return nullptr;
894 
895   return doPromotion(F, ArgsToPromote, ByValArgsToTransform, ReplaceCallSite);
896 }
897 
/// New pass manager entry point: repeatedly attempt argument promotion on
/// every function in the SCC until a fixed point is reached, swapping promoted
/// functions into the lazy call graph and invalidating analyses as needed.
PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C,
                                             CGSCCAnalysisManager &AM,
                                             LazyCallGraph &CG,
                                             CGSCCUpdateResult &UR) {
  bool Changed = false, LocalChange;

  // Iterate until we stop promoting from this SCC.
  do {
    LocalChange = false;

    FunctionAnalysisManager &FAM =
        AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

    // A multi-function SCC is mutually recursive by construction; direct
    // self-recursion is additionally detected inside promoteArguments.
    bool IsRecursive = C.size() > 1;
    for (LazyCallGraph::Node &N : C) {
      Function &OldF = N.getFunction();

      // FIXME: This lambda must only be used with this function. We should
      // skip the lambda and just get the AA results directly.
      auto AARGetter = [&](Function &F) -> AAResults & {
        assert(&F == &OldF && "Called with an unexpected function!");
        return FAM.getResult<AAManager>(F);
      };

      const TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(OldF);
      // No ReplaceCallSite callback (None): the new PM updates the graph
      // itself below instead of via per-call-site notifications.
      Function *NewF = promoteArguments(&OldF, AARGetter, MaxElements, None,
                                        TTI, IsRecursive);
      if (!NewF)
        continue;
      LocalChange = true;

      // Directly substitute the functions in the call graph. Note that this
      // requires the old function to be completely dead and completely
      // replaced by the new function. It does no call graph updates, it merely
      // swaps out the particular function mapped to a particular node in the
      // graph.
      C.getOuterRefSCC().replaceNodeFunction(N, *NewF);
      // Drop cached analyses for the old function *before* erasing it so the
      // analysis manager never holds results keyed on a deleted function.
      FAM.clear(OldF, OldF.getName());
      OldF.eraseFromParent();

      // Invalidate non-CFG analyses for every function that now calls NewF;
      // their call instructions changed, but the CFG is treated as preserved.
      PreservedAnalyses FuncPA;
      FuncPA.preserveSet<CFGAnalyses>();
      for (auto *U : NewF->users()) {
        auto *UserF = cast<CallBase>(U)->getFunction();
        FAM.invalidate(*UserF, FuncPA);
      }
    }

    Changed |= LocalChange;
  } while (LocalChange);

  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  // We've cleared out analyses for deleted functions.
  PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
  // We've manually invalidated analyses for functions we've modified.
  PA.preserveSet<AllAnalysesOn<Function>>();
  return PA;
}
959 
namespace {

/// ArgPromotion - The 'by reference' to 'by value' argument promotion pass.
/// Legacy pass manager wrapper around promoteArguments().
struct ArgPromotion : public CallGraphSCCPass {
  // Pass identification, replacement for typeid
  static char ID;

  /// \param MaxElements maximum number of struct elements to expand per
  ///        promoted byval argument (0 means unlimited; default 3).
  explicit ArgPromotion(unsigned MaxElements = 3)
      : CallGraphSCCPass(ID), MaxElements(MaxElements) {
    initializeArgPromotionPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    // Pull in whatever alias-analysis passes the legacy AA getter needs.
    getAAResultsAnalysisUsage(AU);
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  bool runOnSCC(CallGraphSCC &SCC) override;

private:
  using llvm::Pass::doInitialization;

  bool doInitialization(CallGraph &CG) override;

  /// The maximum number of elements to expand, or 0 for unlimited.
  unsigned MaxElements;
};

} // end anonymous namespace
992 
// Pass identification for the legacy pass manager.
char ArgPromotion::ID = 0;

INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion",
                      "Promote 'by reference' arguments to scalars", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ArgPromotion, "argpromotion",
                    "Promote 'by reference' arguments to scalars", false, false)

/// Factory function for the legacy argument-promotion pass.
/// \param MaxElements maximum number of struct elements to expand per
///        promoted byval argument (0 means unlimited).
Pass *llvm::createArgumentPromotionPass(unsigned MaxElements) {
  return new ArgPromotion(MaxElements);
}
1008 
1009 bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
1010   if (skipSCC(SCC))
1011     return false;
1012 
1013   // Get the callgraph information that we need to update to reflect our
1014   // changes.
1015   CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
1016 
1017   LegacyAARGetter AARGetter(*this);
1018 
1019   bool Changed = false, LocalChange;
1020 
1021   // Iterate until we stop promoting from this SCC.
1022   do {
1023     LocalChange = false;
1024     // Attempt to promote arguments from all functions in this SCC.
1025     bool IsRecursive = SCC.size() > 1;
1026     for (CallGraphNode *OldNode : SCC) {
1027       Function *OldF = OldNode->getFunction();
1028       if (!OldF)
1029         continue;
1030 
1031       auto ReplaceCallSite = [&](CallBase &OldCS, CallBase &NewCS) {
1032         Function *Caller = OldCS.getParent()->getParent();
1033         CallGraphNode *NewCalleeNode =
1034             CG.getOrInsertFunction(NewCS.getCalledFunction());
1035         CallGraphNode *CallerNode = CG[Caller];
1036         CallerNode->replaceCallEdge(cast<CallBase>(OldCS),
1037                                     cast<CallBase>(NewCS), NewCalleeNode);
1038       };
1039 
1040       const TargetTransformInfo &TTI =
1041           getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*OldF);
1042       if (Function *NewF =
1043               promoteArguments(OldF, AARGetter, MaxElements, {ReplaceCallSite},
1044                                TTI, IsRecursive)) {
1045         LocalChange = true;
1046 
1047         // Update the call graph for the newly promoted function.
1048         CallGraphNode *NewNode = CG.getOrInsertFunction(NewF);
1049         NewNode->stealCalledFunctionsFrom(OldNode);
1050         if (OldNode->getNumReferences() == 0)
1051           delete CG.removeFunctionFromModule(OldNode);
1052         else
1053           OldF->setLinkage(Function::ExternalLinkage);
1054 
1055         // And updat ethe SCC we're iterating as well.
1056         SCC.ReplaceNode(OldNode, NewNode);
1057       }
1058     }
1059     // Remember that we changed something.
1060     Changed |= LocalChange;
1061   } while (LocalChange);
1062 
1063   return Changed;
1064 }
1065 
/// Nothing pass-specific to initialize; defer to the CallGraphSCCPass base.
bool ArgPromotion::doInitialization(CallGraph &CG) {
  return CallGraphSCCPass::doInitialization(CG);
}
1069