//===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass promotes "by reference" arguments to be "by value" arguments.  In
// practice, this means looking for internal functions that have pointer
// arguments.  If it can prove, through the use of alias analysis, that an
// argument is *only* loaded, then it can pass the value into the function
// instead of the address of the value.  This can cause recursive simplification
// of code and lead to the elimination of allocas (especially in C++ template
// code like the STL).
//
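// As a rough illustration (the function and value names here are made up),
// a callee such as:
//
//   define internal i32 @callee(i32* %p) {
//     %v = load i32, i32* %p
//     ret i32 %v
//   }
//
// can be rewritten as:
//
//   define internal i32 @callee(i32 %p.val) {
//     ret i32 %p.val
//   }
//
// with each call site updated to load the value and pass it in directly.
//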
// This pass also handles aggregate arguments that are passed into a function,
// scalarizing them if the elements of the aggregate are only loaded.  Note that
// by default it refuses to scalarize aggregates which would require passing in
// more than three operands to the function, because passing thousands of
// operands for a large array or structure is unprofitable! This limit can be
// configured or disabled, however.
//
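// As a sketch of the aggregate case (again with hypothetical names), a byval
// argument such as:
//
//   define internal void @f({ i32, i32 }* byval({ i32, i32 }) align 4 %s)
//
// is rewritten to take the elements directly:
//
//   define internal void @f(i32 %s.0, i32 %s.1)
//
// and the callee rebuilds the aggregate in a local alloca, which SROA can
// usually eliminate afterwards.
//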
// Note that this transformation could also be done for arguments that are only
// stored to (returning the value instead), but this is not currently
// implemented.  That case would be best handled when and if LLVM begins
// supporting multiple return values from functions.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/ArgumentPromotion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using namespace llvm;

#define DEBUG_TYPE "argpromotion"

STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted");
STATISTIC(NumByValArgsPromoted, "Number of byval arguments promoted");
STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated");

namespace {

struct ArgPart {
  Type *Ty;
  Align Alignment;
  /// A representative guaranteed-executed load instruction for use by
  /// metadata transfer.
  LoadInst *MustExecLoad;
};

using OffsetAndArgPart = std::pair<int64_t, ArgPart>;

} // end anonymous namespace

static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL,
                            Value *Ptr, Type *ResElemTy, int64_t Offset) {
  // For non-opaque pointers, try to create a "nice" GEP if possible, otherwise
  // fall back to an i8 GEP to a specific offset.
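  // For example (with an illustrative type): for a pointer of type
  // { i32, i64 }* and byte offset 8, we would prefer emitting
  // "getelementptr { i32, i64 }, ptr, i32 0, i32 1" over a bitcast to i8*
  // followed by an i8 GEP to byte 8.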
  unsigned AddrSpace = Ptr->getType()->getPointerAddressSpace();
  APInt OrigOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset);
  if (!Ptr->getType()->isOpaquePointerTy()) {
    Type *OrigElemTy = Ptr->getType()->getNonOpaquePointerElementType();
    if (OrigOffset == 0 && OrigElemTy == ResElemTy)
      return Ptr;

    if (OrigElemTy->isSized()) {
      APInt TmpOffset = OrigOffset;
      Type *TmpTy = OrigElemTy;
      SmallVector<APInt> IntIndices =
          DL.getGEPIndicesForOffset(TmpTy, TmpOffset);
      if (TmpOffset == 0) {
        // Try to add trailing zero indices to reach the right type.
        while (TmpTy != ResElemTy) {
          Type *NextTy = GetElementPtrInst::getTypeAtIndex(TmpTy, (uint64_t)0);
          if (!NextTy)
            break;

          IntIndices.push_back(APInt::getZero(
              isa<StructType>(TmpTy) ? 32 : OrigOffset.getBitWidth()));
          TmpTy = NextTy;
        }

        SmallVector<Value *> Indices;
        for (const APInt &Index : IntIndices)
          Indices.push_back(IRB.getInt(Index));

        if (OrigOffset != 0 || TmpTy == ResElemTy) {
          Ptr = IRB.CreateGEP(OrigElemTy, Ptr, Indices);
          return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace));
        }
      }
    }
  }

  if (OrigOffset != 0) {
    Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy(AddrSpace));
    Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(OrigOffset));
  }
  return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace));
}

/// doPromotion - This method actually performs the promotion of the specified
/// arguments, and returns the new function.  At this point, we know that it's
/// safe to do so.
static Function *doPromotion(
    Function *F,
    const DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> &ArgsToPromote,
    SmallPtrSetImpl<Argument *> &ByValArgsToTransform,
    Optional<function_ref<void(CallBase &OldCS, CallBase &NewCS)>>
        ReplaceCallSite) {
  // Start by computing a new prototype for the function, which is the same as
  // the old function, but has modified arguments.
  FunctionType *FTy = F->getFunctionType();
  std::vector<Type *> Params;

  // ArgAttrVec - Keep track of the parameter attributes for the arguments
  // that we are *not* promoting. For the ones that we do promote, the
  // parameter attributes are lost.
  SmallVector<AttributeSet, 8> ArgAttrVec;
  AttributeList PAL = F->getAttributes();

  // First, determine the new argument list.
  unsigned ArgNo = 0;
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++ArgNo) {
    if (ByValArgsToTransform.count(&*I)) {
      // Simple byval argument? Just add all the struct element types.
      Type *AgTy = I->getParamByValType();
      StructType *STy = cast<StructType>(AgTy);
      llvm::append_range(Params, STy->elements());
      ArgAttrVec.insert(ArgAttrVec.end(), STy->getNumElements(),
                        AttributeSet());
      ++NumByValArgsPromoted;
    } else if (!ArgsToPromote.count(&*I)) {
      // Unchanged argument.
      Params.push_back(I->getType());
      ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo));
    } else if (I->use_empty()) {
      // Dead argument (these are always marked as promotable).
      ++NumArgumentsDead;
    } else {
      const auto &ArgParts = ArgsToPromote.find(&*I)->second;
      for (const auto &Pair : ArgParts) {
        Params.push_back(Pair.second.Ty);
        ArgAttrVec.push_back(AttributeSet());
      }
      ++NumArgumentsPromoted;
    }
  }

  Type *RetTy = FTy->getReturnType();

  // Construct the new function type using the new arguments.
  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());

  // Create the new function body and insert it into the module.
  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(),
                                  F->getName());
  NF->copyAttributesFrom(F);
  NF->copyMetadata(F, 0);

  // The new function will have the !dbg metadata copied from the original
  // function. The original function might not be deleted, and dbg metadata
  // needs to be unique, so we need to drop it from the original function.
  F->setSubprogram(nullptr);

  LLVM_DEBUG(dbgs() << "ARG PROMOTION:  Promoting to:" << *NF << "\n"
                    << "From: " << *F);

  uint64_t LargestVectorWidth = 0;
  for (auto *I : Params)
    if (auto *VT = dyn_cast<llvm::VectorType>(I))
      LargestVectorWidth = std::max(
          LargestVectorWidth, VT->getPrimitiveSizeInBits().getKnownMinSize());

  // Recompute the parameter attributes list based on the new arguments for
  // the function.
  NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(),
                                       PAL.getRetAttrs(), ArgAttrVec));
  AttributeFuncs::updateMinLegalVectorWidthAttr(*NF, LargestVectorWidth);
  ArgAttrVec.clear();

  F->getParent()->getFunctionList().insert(F->getIterator(), NF);
  NF->takeName(F);

  // Loop over all of the callers of the function, transforming the call sites
  // to pass in the loaded pointers.
  SmallVector<Value *, 16> Args;
  const DataLayout &DL = F->getParent()->getDataLayout();
  while (!F->use_empty()) {
    CallBase &CB = cast<CallBase>(*F->user_back());
    assert(CB.getCalledFunction() == F);
    const AttributeList &CallPAL = CB.getAttributes();
    IRBuilder<NoFolder> IRB(&CB);

    // Loop over the operands, inserting GEPs and loads in the caller as
    // appropriate.
    auto *AI = CB.arg_begin();
    ArgNo = 0;
    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
         ++I, ++AI, ++ArgNo)
      if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) {
        Args.push_back(*AI); // Unmodified argument.
        ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
      } else if (ByValArgsToTransform.count(&*I)) {
        // Emit a GEP and load for each element of the struct.
        Type *AgTy = I->getParamByValType();
        StructType *STy = cast<StructType>(AgTy);
        Value *Idxs[2] = {
            ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr};
        const StructLayout *SL = DL.getStructLayout(STy);
        Align StructAlign = *I->getParamAlign();
        for (unsigned J = 0, Elems = STy->getNumElements(); J != Elems; ++J) {
          Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), J);
          auto *Idx =
              IRB.CreateGEP(STy, *AI, Idxs, (*AI)->getName() + "." + Twine(J));
          // TODO: Tell AA about the new values?
          Align Alignment =
              commonAlignment(StructAlign, SL->getElementOffset(J));
          Args.push_back(IRB.CreateAlignedLoad(
              STy->getElementType(J), Idx, Alignment, Idx->getName() + ".val"));
          ArgAttrVec.push_back(AttributeSet());
        }
      } else if (!I->use_empty()) {
        Value *V = *AI;
        const auto &ArgParts = ArgsToPromote.find(&*I)->second;
        for (const auto &Pair : ArgParts) {
          LoadInst *LI = IRB.CreateAlignedLoad(
              Pair.second.Ty,
              createByteGEP(IRB, DL, V, Pair.second.Ty, Pair.first),
              Pair.second.Alignment, V->getName() + ".val");
          if (Pair.second.MustExecLoad) {
            LI->setAAMetadata(Pair.second.MustExecLoad->getAAMetadata());
            LI->copyMetadata(*Pair.second.MustExecLoad,
                             {LLVMContext::MD_range, LLVMContext::MD_nonnull,
                              LLVMContext::MD_dereferenceable,
                              LLVMContext::MD_dereferenceable_or_null,
                              LLVMContext::MD_align, LLVMContext::MD_noundef});
          }
          Args.push_back(LI);
          ArgAttrVec.push_back(AttributeSet());
        }
      }

    // Push any varargs arguments on the list.
    for (; AI != CB.arg_end(); ++AI, ++ArgNo) {
      Args.push_back(*AI);
      ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
    }

    SmallVector<OperandBundleDef, 1> OpBundles;
    CB.getOperandBundlesAsDefs(OpBundles);

    CallBase *NewCS = nullptr;
    if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) {
      NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
                                 Args, OpBundles, "", &CB);
    } else {
      auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", &CB);
      NewCall->setTailCallKind(cast<CallInst>(&CB)->getTailCallKind());
      NewCS = NewCall;
    }
    NewCS->setCallingConv(CB.getCallingConv());
    NewCS->setAttributes(AttributeList::get(F->getContext(),
                                            CallPAL.getFnAttrs(),
                                            CallPAL.getRetAttrs(), ArgAttrVec));
    NewCS->copyMetadata(CB, {LLVMContext::MD_prof, LLVMContext::MD_dbg});
    Args.clear();
    ArgAttrVec.clear();

    AttributeFuncs::updateMinLegalVectorWidthAttr(*CB.getCaller(),
                                                  LargestVectorWidth);

    // Update the callgraph to know that the callsite has been transformed.
    if (ReplaceCallSite)
      (*ReplaceCallSite)(CB, *NewCS);

    if (!CB.use_empty()) {
      CB.replaceAllUsesWith(NewCS);
      NewCS->takeName(&CB);
    }

    // Finally, remove the old call from the program, reducing the use-count of
    // F.
    CB.eraseFromParent();
  }

  // Since we have now created the new function, splice the body of the old
  // function right into the new function, leaving the old rotting hulk of the
  // function empty.
  NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());

  // Loop over the argument list, transferring uses of the old arguments over
  // to the new arguments, and transferring the names over as well.
  Function::arg_iterator I2 = NF->arg_begin();
  for (Argument &Arg : F->args()) {
    if (!ArgsToPromote.count(&Arg) && !ByValArgsToTransform.count(&Arg)) {
      // If this is an unmodified argument, move the name and users over to the
      // new version.
      Arg.replaceAllUsesWith(&*I2);
      I2->takeName(&Arg);
      ++I2;
      continue;
    }

    if (ByValArgsToTransform.count(&Arg)) {
      // In the callee, we create an alloca, and store each of the new incoming
      // arguments into the alloca.
      Instruction *InsertPt = &NF->begin()->front();

      // Just add all the struct element types.
      Type *AgTy = Arg.getParamByValType();
      Align StructAlign = *Arg.getParamAlign();
      Value *TheAlloca = new AllocaInst(AgTy, DL.getAllocaAddrSpace(), nullptr,
                                        StructAlign, "", InsertPt);
      StructType *STy = cast<StructType>(AgTy);
      Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(F->getContext()), 0),
                        nullptr};
      const StructLayout *SL = DL.getStructLayout(STy);

      for (unsigned J = 0, Elems = STy->getNumElements(); J != Elems; ++J) {
        Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), J);
        Value *Idx = GetElementPtrInst::Create(
            AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(J),
            InsertPt);
        I2->setName(Arg.getName() + "." + Twine(J));
        Align Alignment = commonAlignment(StructAlign, SL->getElementOffset(J));
        new StoreInst(&*I2++, Idx, false, Alignment, InsertPt);
      }

      // Anything that used the arg should now use the alloca.
      Arg.replaceAllUsesWith(TheAlloca);
      TheAlloca->takeName(&Arg);
      continue;
    }

    // There potentially are metadata uses for things like llvm.dbg.value.
    // Replace them with undef, after handling the other regular uses.
    auto RauwUndefMetadata = make_scope_exit(
        [&]() { Arg.replaceAllUsesWith(UndefValue::get(Arg.getType())); });

    if (Arg.use_empty())
      continue;

    SmallDenseMap<int64_t, Argument *> OffsetToArg;
    for (const auto &Pair : ArgsToPromote.find(&Arg)->second) {
      Argument &NewArg = *I2++;
      NewArg.setName(Arg.getName() + "." + Twine(Pair.first) + ".val");
      OffsetToArg.insert({Pair.first, &NewArg});
    }

    // Otherwise, if we promoted this argument, then all users are load
    // instructions (with possible casts and GEPs in between).

    SmallVector<Value *, 16> Worklist;
    SmallVector<Instruction *, 16> DeadInsts;
    append_range(Worklist, Arg.users());
    while (!Worklist.empty()) {
      Value *V = Worklist.pop_back_val();
      if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V)) {
        DeadInsts.push_back(cast<Instruction>(V));
        append_range(Worklist, V->users());
        continue;
      }

      if (auto *LI = dyn_cast<LoadInst>(V)) {
        Value *Ptr = LI->getPointerOperand();
        APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
        Ptr =
            Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                   /* AllowNonInbounds */ true);
        assert(Ptr == &Arg && "Not constant offset from arg?");
        LI->replaceAllUsesWith(OffsetToArg[Offset.getSExtValue()]);
        DeadInsts.push_back(LI);
        continue;
      }

      llvm_unreachable("Unexpected user");
    }

    for (Instruction *I : DeadInsts) {
      I->replaceAllUsesWith(UndefValue::get(I->getType()));
      I->eraseFromParent();
    }
  }

  return NF;
}

/// Return true if we can prove that all callers pass in a valid pointer for
/// the specified function argument.
static bool allCallersPassValidPointerForArgument(Argument *Arg,
                                                  Align NeededAlign,
                                                  uint64_t NeededDerefBytes) {
  Function *Callee = Arg->getParent();
  const DataLayout &DL = Callee->getParent()->getDataLayout();
  APInt Bytes(64, NeededDerefBytes);

  // Check if the argument itself is marked dereferenceable and aligned.
  if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL))
    return true;

  // Look at all call sites of the function.  At this point we know we only
  // have direct callers.
  return all_of(Callee->users(), [&](User *U) {
    CallBase &CB = cast<CallBase>(*U);
    return isDereferenceableAndAlignedPointer(
        CB.getArgOperand(Arg->getArgNo()), NeededAlign, Bytes, DL);
  });
}

/// Determine whether this argument is safe to promote, and if so, find the
/// argument parts it can be promoted into.
static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
                         unsigned MaxElements, bool IsRecursive,
                         SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
  // Quick exit for unused arguments.
  if (Arg->use_empty())
    return true;

  // We can only promote this argument if all of the uses are loads at known
  // offsets.
  //
  // Promoting the argument causes it to be loaded in the caller
  // unconditionally. This is only safe if we can prove that either the load
  // would have happened in the callee anyway (i.e., there is a load in the
  // entry block) or the pointer passed in at every call site is guaranteed to
  // be valid.
  // In the former case, invalid loads can happen, but would have happened
  // anyway; in the latter case, invalid loads won't happen. This prevents us
  // from introducing an invalid load that wouldn't have happened in the
  // original code.
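  // As a hypothetical example: a load of the argument in the entry block,
  // before any instruction that might not transfer execution to its
  // successor, satisfies the first condition; a load behind a conditional
  // branch does not, and is then only promotable if every caller passes a
  // sufficiently dereferenceable and aligned pointer.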

  SmallDenseMap<int64_t, ArgPart, 4> ArgParts;
  Align NeededAlign(1);
  uint64_t NeededDerefBytes = 0;

  // Returns None if this load is not based on the argument. Returns true if
  // we can promote the load, false otherwise.
  auto HandleLoad = [&](LoadInst *LI,
                        bool GuaranteedToExecute) -> Optional<bool> {
    // Don't promote volatile or atomic loads.
    if (!LI->isSimple())
      return false;

    Value *Ptr = LI->getPointerOperand();
    APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
    Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                 /* AllowNonInbounds */ true);
    if (Ptr != Arg)
      return None;

    if (Offset.getSignificantBits() >= 64)
      return false;

    Type *Ty = LI->getType();
    TypeSize Size = DL.getTypeStoreSize(Ty);
    // Don't try to promote scalable types.
    if (Size.isScalable())
      return false;

    // If this is a recursive function and one of the types is a pointer,
    // then promoting it might lead to recursive promotion.
    if (IsRecursive && Ty->isPointerTy())
      return false;

    int64_t Off = Offset.getSExtValue();
    auto Pair = ArgParts.try_emplace(
        Off, ArgPart{Ty, LI->getAlign(), GuaranteedToExecute ? LI : nullptr});
    ArgPart &Part = Pair.first->second;
    bool OffsetNotSeenBefore = Pair.second;

    // We limit promotion to only promoting up to a fixed number of elements of
    // the aggregate.
    if (MaxElements > 0 && ArgParts.size() > MaxElements) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "more than " << MaxElements << " parts\n");
      return false;
    }

    // For now, we only support loading one specific type at a given offset.
    if (Part.Ty != Ty) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "loaded via both " << *Part.Ty << " and " << *Ty
                        << " at offset " << Off << "\n");
      return false;
    }

    // If this load is not guaranteed to execute, and we haven't seen a load at
    // this offset before (or it had lower alignment), then we need to remember
    // that requirement.
    // Note that skipping loads of previously seen offsets is only correct
    // because we only allow a single type for a given offset, which also means
    // that the number of accessed bytes will be the same.
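    // For example, a second non-guaranteed i32 load at an offset we already
    // recorded, at equal or lower alignment, adds no new requirement.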
    if (!GuaranteedToExecute &&
        (OffsetNotSeenBefore || Part.Alignment < LI->getAlign())) {
      // We won't be able to prove dereferenceability for negative offsets.
      if (Off < 0)
        return false;

      // If the offset is not aligned, an aligned base pointer won't help.
      if (!isAligned(LI->getAlign(), Off))
        return false;

      NeededDerefBytes = std::max(NeededDerefBytes, Off + Size.getFixedValue());
      NeededAlign = std::max(NeededAlign, LI->getAlign());
    }

    Part.Alignment = std::max(Part.Alignment, LI->getAlign());
    return true;
  };

  // Look for loads that are guaranteed to execute on entry.
  for (Instruction &I : Arg->getParent()->getEntryBlock()) {
    if (LoadInst *LI = dyn_cast<LoadInst>(&I))
      if (Optional<bool> Res = HandleLoad(LI, /* GuaranteedToExecute */ true))
        if (!*Res)
          return false;

    if (!isGuaranteedToTransferExecutionToSuccessor(&I))
      break;
  }

  // Now look at all loads of the argument. Remember the load instructions
  // for the aliasing check below.
  SmallVector<Value *, 16> Worklist;
  SmallPtrSet<Value *, 16> Visited;
  SmallVector<LoadInst *, 16> Loads;
  auto AppendUsers = [&](Value *V) {
    for (User *U : V->users())
      if (Visited.insert(U).second)
        Worklist.push_back(U);
  };
  AppendUsers(Arg);
  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    if (isa<BitCastInst>(V)) {
      AppendUsers(V);
      continue;
    }

    if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
      if (!GEP->hasAllConstantIndices())
        return false;
      AppendUsers(V);
      continue;
    }

    if (auto *LI = dyn_cast<LoadInst>(V)) {
      if (!*HandleLoad(LI, /* GuaranteedToExecute */ false))
        return false;
      Loads.push_back(LI);
      continue;
    }

    // Unknown user.
    LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                      << "unknown user " << *V << "\n");
    return false;
  }

  if (NeededDerefBytes || NeededAlign > 1) {
    // Try to prove the required dereferenceability and alignment.
    if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
                                               NeededDerefBytes)) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "not dereferenceable or aligned\n");
      return false;
    }
  }

  if (ArgParts.empty())
    return true; // No users, this is a dead argument.

  // Sort parts by offset.
  append_range(ArgPartsVec, ArgParts);
  sort(ArgPartsVec,
       [](const auto &A, const auto &B) { return A.first < B.first; });

  // Make sure the parts are non-overlapping.
  // TODO: As we're doing pure load promotion here, overlap should be fine from
  // a correctness perspective. Profitability is less obvious though.
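  // For example, an i64 part at offset 0 and an i32 part at offset 4 would
  // overlap, so the argument would be rejected here.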
  int64_t Offset = ArgPartsVec[0].first;
  for (const auto &Pair : ArgPartsVec) {
    if (Pair.first < Offset)
      return false; // Overlap with previous part.

    Offset = Pair.first + DL.getTypeStoreSize(Pair.second.Ty);
  }

  // Okay, now we know that the argument is only used by load instructions, and
  // it is safe to unconditionally perform all of them. Use alias analysis to
  // check to see if the pointer is guaranteed to not be modified from entry of
  // the function to each of the load instructions.

  // Because there could be many load instructions, remember which blocks we
  // know to be transparent to the load.
  df_iterator_default_set<BasicBlock *, 16> TranspBlocks;

  for (LoadInst *Load : Loads) {
    // Check to see if the load is invalidated from the start of the block to
    // the load itself.
    BasicBlock *BB = Load->getParent();

    MemoryLocation Loc = MemoryLocation::get(Load);
    if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, ModRefInfo::Mod))
      return false; // Pointer is invalidated!

    // Now check every path from the entry block to the load for transparency.
    // To do this, we perform a depth first search on the inverse CFG from the
    // loading block.
    for (BasicBlock *P : predecessors(BB)) {
      for (BasicBlock *TranspBB : inverse_depth_first_ext(P, TranspBlocks))
        if (AAR.canBasicBlockModify(*TranspBB, Loc))
          return false;
    }
  }

  // If the path from the entry of the function to each load is free of
  // instructions that potentially invalidate the load, we can make the
  // transformation!
  return true;
}

bool ArgumentPromotionPass::isDenselyPacked(Type *Ty, const DataLayout &DL) {
  // There is no size information, so be conservative.
  if (!Ty->isSized())
    return false;

  // If the alloc size is not equal to the storage size, then there are padding
  // bytes. For x86_fp80 on x86-64, size: 80, alloc size: 128.
  if (DL.getTypeSizeInBits(Ty) != DL.getTypeAllocSizeInBits(Ty))
    return false;

  // FIXME: This isn't the right way to check for padding in vectors with
  // non-byte-size elements.
  if (VectorType *SeqTy = dyn_cast<VectorType>(Ty))
    return isDenselyPacked(SeqTy->getElementType(), DL);

  // For array types, check for padding within members.
  if (ArrayType *SeqTy = dyn_cast<ArrayType>(Ty))
    return isDenselyPacked(SeqTy->getElementType(), DL);

  if (!isa<StructType>(Ty))
    return true;

  // Check for padding within and between elements of a struct.
  StructType *StructTy = cast<StructType>(Ty);
  const StructLayout *Layout = DL.getStructLayout(StructTy);
  uint64_t StartPos = 0;
  for (unsigned I = 0, E = StructTy->getNumElements(); I < E; ++I) {
    Type *ElTy = StructTy->getElementType(I);
    if (!isDenselyPacked(ElTy, DL))
      return false;
    if (StartPos != Layout->getElementOffsetInBits(I))
      return false;
    StartPos += DL.getTypeAllocSizeInBits(ElTy);
  }

  return true;
}

/// Checks if the padding bytes of an argument could be accessed.
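/// As a hypothetical example: if a pointer derived from the byval aggregate
/// is itself stored to memory, code elsewhere could observe the aggregate
/// (including its padding bytes) through that escaped pointer, so we
/// conservatively report that the padding may be accessed.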
static bool canPaddingBeAccessed(Argument *Arg) {
  assert(Arg->hasByValAttr());

  // Track all the pointers to the argument to make sure they are not captured.
  SmallPtrSet<Value *, 16> PtrValues;
  PtrValues.insert(Arg);

  // Track all of the stores.
  SmallVector<StoreInst *, 16> Stores;

  // Scan through the uses recursively to make sure the pointer is always used
  // sanely.
  SmallVector<Value *, 16> WorkList(Arg->users());
  while (!WorkList.empty()) {
    Value *V = WorkList.pop_back_val();
    if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) {
      if (PtrValues.insert(V).second)
        append_range(WorkList, V->users());
    } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) {
      Stores.push_back(Store);
    } else if (!isa<LoadInst>(V)) {
      return true;
    }
  }

  // Check to make sure the pointers aren't captured.
  for (StoreInst *Store : Stores)
    if (PtrValues.count(Store->getValueOperand()))
      return true;

  return false;
}

/// Check if callers and callee agree on how promoted arguments would be
/// passed.
static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
                                  const TargetTransformInfo &TTI) {
  return all_of(F.uses(), [&](const Use &U) {
    CallBase *CB = dyn_cast<CallBase>(U.getUser());
    if (!CB)
      return false;

    const Function *Caller = CB->getCaller();
    const Function *Callee = CB->getCalledFunction();
    return TTI.areTypesABICompatible(Caller, Callee, Types);
  });
}

/// promoteArguments - This method checks the specified function to see if
/// there are any promotable arguments and if it is safe to promote the
/// function (for example, all callers are direct).  If it is safe to promote
/// some arguments, it calls the doPromotion method.
static Function *
promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter,
                 unsigned MaxElements,
                 Optional<function_ref<void(CallBase &OldCS, CallBase &NewCS)>>
                     ReplaceCallSite,
                 const TargetTransformInfo &TTI, bool IsRecursive) {
  // Don't perform argument promotion for naked functions; otherwise we can end
  // up removing parameters that are seemingly 'not used' as they are referred
  // to in the assembly.
  if (F->hasFnAttribute(Attribute::Naked))
    return nullptr;

  // Make sure that it is local to this module.
  if (!F->hasLocalLinkage())
    return nullptr;

  // Don't promote arguments for variadic functions. Adding, removing, or
  // changing non-pack parameters can change the classification of pack
  // parameters. Frontends encode that classification at the call site in the
  // IR, while in the callee the classification is determined dynamically based
  // on the number of registers consumed so far.
  if (F->isVarArg())
    return nullptr;

  // Don't transform functions that receive inallocas, as the transformation
  // may not be safe depending on calling convention.
  if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca))
    return nullptr;

  // First check: see if there are any pointer arguments!  If not, quick exit.
  SmallVector<Argument *, 16> PointerArgs;
  for (Argument &I : F->args())
    if (I.getType()->isPointerTy())
      PointerArgs.push_back(&I);
  if (PointerArgs.empty())
    return nullptr;

  // Second check: make sure that all callers are direct callers.  We can't
  // transform functions that have indirect callers.  Also see if the function
  // is self-recursive.
  for (Use &U : F->uses()) {
    CallBase *CB = dyn_cast<CallBase>(U.getUser());
    // Must be a direct call.
    if (CB == nullptr || !CB->isCallee(&U) ||
        CB->getFunctionType() != F->getFunctionType())
      return nullptr;

    // Can't change the signature of a musttail callee.
    if (CB->isMustTailCall())
      return nullptr;

    if (CB->getFunction() == F)
      IsRecursive = true;
  }

  // Can't change the signature of a musttail caller either.
  // FIXME: Support promoting whole chain of musttail functions.
  for (BasicBlock &BB : *F)
    if (BB.getTerminatingMustTailCall())
      return nullptr;

  const DataLayout &DL = F->getParent()->getDataLayout();

  AAResults &AAR = AARGetter(*F);

  // Check to see which arguments are promotable.  If an argument is
  // promotable, add it to ArgsToPromote.
  DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> ArgsToPromote;
  SmallPtrSet<Argument *, 8> ByValArgsToTransform;
  for (Argument *PtrArg : PointerArgs) {
    // Replace the sret attribute with noalias. This reduces register pressure
    // by avoiding a register copy.
    if (PtrArg->hasStructRetAttr()) {
      unsigned ArgNo = PtrArg->getArgNo();
      F->removeParamAttr(ArgNo, Attribute::StructRet);
      F->addParamAttr(ArgNo, Attribute::NoAlias);
      for (Use &U : F->uses()) {
        CallBase &CB = cast<CallBase>(*U.getUser());
        CB.removeParamAttr(ArgNo, Attribute::StructRet);
        CB.addParamAttr(ArgNo, Attribute::NoAlias);
      }
    }

    // See if we can promote the pointer to its value.
    SmallVector<OffsetAndArgPart, 4> ArgParts;
    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
      SmallVector<Type *, 4> Types;
      for (const auto &Pair : ArgParts)
        Types.push_back(Pair.second.Ty);

      if (areTypesABICompatible(Types, *F, TTI)) {
        ArgsToPromote.insert({PtrArg, std::move(ArgParts)});
        continue;
      }
    }

    // Otherwise, if this is a byval argument and the aggregate type is small,
    // just pass in the elements instead. This is always safe if the passed
    // value is densely packed or if we can prove the padding bytes are never
    // accessed.
    //
    // Only handle arguments with specified alignment; if it's unspecified, the
    // actual alignment of the argument is target-specific.
    Type *ByValTy = PtrArg->getParamByValType();
    bool IsSafeToPromote =
        ByValTy && PtrArg->getParamAlign() &&
        (ArgumentPromotionPass::isDenselyPacked(ByValTy, DL) ||
         !canPaddingBeAccessed(PtrArg));
    if (!IsSafeToPromote) {
      LLVM_DEBUG(dbgs() << "ArgPromotion disables passing the elements of"
                        << " the argument '" << PtrArg->getName()
                        << "' because it is not safe.\n");
      continue;
    }
    if (StructType *STy = dyn_cast<StructType>(ByValTy)) {
      if (MaxElements > 0 && STy->getNumElements() > MaxElements) {
        LLVM_DEBUG(dbgs() << "ArgPromotion disables passing the elements of"
                          << " the argument '" << PtrArg->getName()
                          << "' because it would require adding more"
                          << " than " << MaxElements
                          << " arguments to the function.\n");
        continue;
      }
      SmallVector<Type *, 4> Types;
      append_range(Types, STy->elements());

      // If all the elements are single-value types, we can promote it.
      bool AllSimple =
          all_of(Types, [](Type *Ty) { return Ty->isSingleValueType(); });

      // Safe to transform. Passing the elements as scalars will allow SROA to
      // hack on the new alloca we introduce.
      if (AllSimple && areTypesABICompatible(Types, *F, TTI))
        ByValArgsToTransform.insert(PtrArg);
    }
  }

  // No promotable pointer arguments.
  if (ArgsToPromote.empty() && ByValArgsToTransform.empty())
    return nullptr;

  return doPromotion(F, ArgsToPromote, ByValArgsToTransform, ReplaceCallSite);
}

PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C,
                                             CGSCCAnalysisManager &AM,
                                             LazyCallGraph &CG,
                                             CGSCCUpdateResult &UR) {
  bool Changed = false, LocalChange;

  // Iterate until we stop promoting from this SCC.
  do {
    LocalChange = false;

    FunctionAnalysisManager &FAM =
        AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

    bool IsRecursive = C.size() > 1;
    for (LazyCallGraph::Node &N : C) {
      Function &OldF = N.getFunction();

      // FIXME: This lambda must only be used with this function. We should
      // skip the lambda and just get the AA results directly.
      auto AARGetter = [&](Function &F) -> AAResults & {
        assert(&F == &OldF && "Called with an unexpected function!");
        return FAM.getResult<AAManager>(F);
      };

      const TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(OldF);
      Function *NewF = promoteArguments(&OldF, AARGetter, MaxElements, None,
                                        TTI, IsRecursive);
      if (!NewF)
        continue;
      LocalChange = true;

      // Directly substitute the functions in the call graph. Note that this
      // requires the old function to be completely dead and completely
      // replaced by the new function. It does no call graph updates, it merely
      // swaps out the particular function mapped to a particular node in the
      // graph.
      C.getOuterRefSCC().replaceNodeFunction(N, *NewF);
      FAM.clear(OldF, OldF.getName());
      OldF.eraseFromParent();

      PreservedAnalyses FuncPA;
      FuncPA.preserveSet<CFGAnalyses>();
      for (auto *U : NewF->users()) {
        auto *UserF = cast<CallBase>(U)->getFunction();
        FAM.invalidate(*UserF, FuncPA);
      }
    }

    Changed |= LocalChange;
  } while (LocalChange);

  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  // We've cleared out analyses for deleted functions.
  PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
  // We've manually invalidated analyses for functions we've modified.
  PA.preserveSet<AllAnalysesOn<Function>>();
  return PA;
}

namespace {

/// ArgPromotion - The 'by reference' to 'by value' argument promotion pass.
struct ArgPromotion : public CallGraphSCCPass {
  // Pass identification, replacement for typeid
  static char ID;

  explicit ArgPromotion(unsigned MaxElements = 3)
      : CallGraphSCCPass(ID), MaxElements(MaxElements) {
    initializeArgPromotionPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    getAAResultsAnalysisUsage(AU);
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  bool runOnSCC(CallGraphSCC &SCC) override;

private:
  using llvm::Pass::doInitialization;

  bool doInitialization(CallGraph &CG) override;

  /// The maximum number of elements to expand, or 0 for unlimited.
  unsigned MaxElements;
};

} // end anonymous namespace

char ArgPromotion::ID = 0;

INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion",
                      "Promote 'by reference' arguments to scalars", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ArgPromotion, "argpromotion",
                    "Promote 'by reference' arguments to scalars", false, false)

Pass *llvm::createArgumentPromotionPass(unsigned MaxElements) {
  return new ArgPromotion(MaxElements);
}

bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
  if (skipSCC(SCC))
    return false;

  // Get the callgraph information that we need to update to reflect our
  // changes.
  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();

  LegacyAARGetter AARGetter(*this);

  bool Changed = false, LocalChange;

  // Iterate until we stop promoting from this SCC.
  do {
    LocalChange = false;
    // Attempt to promote arguments from all functions in this SCC.
    bool IsRecursive = SCC.size() > 1;
    for (CallGraphNode *OldNode : SCC) {
      Function *OldF = OldNode->getFunction();
      if (!OldF)
        continue;

      auto ReplaceCallSite = [&](CallBase &OldCS, CallBase &NewCS) {
        Function *Caller = OldCS.getParent()->getParent();
        CallGraphNode *NewCalleeNode =
            CG.getOrInsertFunction(NewCS.getCalledFunction());
        CallGraphNode *CallerNode = CG[Caller];
        CallerNode->replaceCallEdge(OldCS, NewCS, NewCalleeNode);
      };

      const TargetTransformInfo &TTI =
          getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*OldF);
      if (Function *NewF =
              promoteArguments(OldF, AARGetter, MaxElements, {ReplaceCallSite},
                               TTI, IsRecursive)) {
        LocalChange = true;

        // Update the call graph for the newly promoted function.
        CallGraphNode *NewNode = CG.getOrInsertFunction(NewF);
        NewNode->stealCalledFunctionsFrom(OldNode);
        if (OldNode->getNumReferences() == 0)
          delete CG.removeFunctionFromModule(OldNode);
        else
          OldF->setLinkage(Function::ExternalLinkage);

        // And update the SCC we're iterating as well.
        SCC.ReplaceNode(OldNode, NewNode);
      }
    }
    // Remember that we changed something.
    Changed |= LocalChange;
  } while (LocalChange);

  return Changed;
}

bool ArgPromotion::doInitialization(CallGraph &CG) {
  return CallGraphSCCPass::doInitialization(CG);
}