1 //===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "llvm/Analysis/StackSafetyAnalysis.h"
12 #include "llvm/ADT/APInt.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/ModuleSummaryAnalysis.h"
17 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
18 #include "llvm/Analysis/StackLifetime.h"
19 #include "llvm/IR/ConstantRange.h"
20 #include "llvm/IR/DerivedTypes.h"
21 #include "llvm/IR/GlobalValue.h"
22 #include "llvm/IR/InstIterator.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/InitializePasses.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <memory>
32 
33 using namespace llvm;
34 
35 #define DEBUG_TYPE "stack-safety"
36 
// Alloca counters, incremented in StackSafetyGlobalInfo::getInfo(): every
// analyzed alloca counts towards NumAllocaTotal, and those whose accessed range
// fits their static size also count towards NumAllocaStackSafe.
STATISTIC(NumAllocaStackSafe, "Number of safe allocas");
STATISTIC(NumAllocaTotal, "Number of total allocas");

// Callee-resolution and summary statistics. Only the NumModuleCallee* pair is
// updated in the visible part of this file (resolveAllCalls); the remaining
// counters are presumably updated by summary-index code further down — verify.
STATISTIC(NumCombinedCalleeLookupTotal,
          "Number of total callee lookups on combined index.");
STATISTIC(NumCombinedCalleeLookupFailed,
          "Number of failed callee lookups on combined index.");
STATISTIC(NumModuleCalleeLookupTotal,
          "Number of total callee lookups on module index.");
STATISTIC(NumModuleCalleeLookupFailed,
          "Number of failed callee lookups on module index.");
STATISTIC(NumCombinedParamAccessesBefore,
          "Number of total param accesses before generateParamAccessSummary.");
STATISTIC(NumCombinedParamAccessesAfter,
          "Number of total param accesses after generateParamAccessSummary.");
STATISTIC(NumCombinedDataFlowNodes,
          "Number of total nodes in combined index for dataflow processing.");
54 
55 static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations",
56                                              cl::init(20), cl::Hidden);
57 
58 static cl::opt<bool> StackSafetyPrint("stack-safety-print", cl::init(false),
59                                       cl::Hidden);
60 
61 static cl::opt<bool> StackSafetyRun("stack-safety-run", cl::init(false),
62                                     cl::Hidden);
63 
64 namespace {
65 
66 // Check if we should bailout for such ranges.
67 bool isUnsafe(const ConstantRange &R) {
68   return R.isEmptySet() || R.isFullSet() || R.isUpperSignWrapped();
69 }
70 
71 ConstantRange addOverflowNever(const ConstantRange &L, const ConstantRange &R) {
72   assert(!L.isSignWrappedSet());
73   assert(!R.isSignWrappedSet());
74   if (L.signedAddMayOverflow(R) !=
75       ConstantRange::OverflowResult::NeverOverflows)
76     return ConstantRange::getFull(L.getBitWidth());
77   ConstantRange Result = L.add(R);
78   assert(!Result.isSignWrappedSet());
79   return Result;
80 }
81 
82 ConstantRange unionNoWrap(const ConstantRange &L, const ConstantRange &R) {
83   assert(!L.isSignWrappedSet());
84   assert(!R.isSignWrappedSet());
85   auto Result = L.unionWith(R);
86   // Two non-wrapped sets can produce wrapped.
87   if (Result.isSignWrappedSet())
88     Result = ConstantRange::getFull(Result.getBitWidth());
89   return Result;
90 }
91 
92 /// Describes use of address in as a function call argument.
93 template <typename CalleeTy> struct CallInfo {
94   /// Function being called.
95   const CalleeTy *Callee = nullptr;
96   /// Index of argument which pass address.
97   size_t ParamNo = 0;
98   // Offset range of address from base address (alloca or calling function
99   // argument).
100   // Range should never set to empty-set, that is an invalid access range
101   // that can cause empty-set to be propagated with ConstantRange::add
102   ConstantRange Offset;
103   CallInfo(const CalleeTy *Callee, size_t ParamNo, const ConstantRange &Offset)
104       : Callee(Callee), ParamNo(ParamNo), Offset(Offset) {}
105 };
106 
107 template <typename CalleeTy>
108 raw_ostream &operator<<(raw_ostream &OS, const CallInfo<CalleeTy> &P) {
109   return OS << "@" << P.Callee->getName() << "(arg" << P.ParamNo << ", "
110             << P.Offset << ")";
111 }
112 
113 /// Describe uses of address (alloca or parameter) inside of the function.
114 template <typename CalleeTy> struct UseInfo {
115   // Access range if the address (alloca or parameters).
116   // It is allowed to be empty-set when there are no known accesses.
117   ConstantRange Range;
118 
119   // List of calls which pass address as an argument.
120   SmallVector<CallInfo<CalleeTy>, 4> Calls;
121 
122   UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
123 
124   void updateRange(const ConstantRange &R) { Range = unionNoWrap(Range, R); }
125 };
126 
127 template <typename CalleeTy>
128 raw_ostream &operator<<(raw_ostream &OS, const UseInfo<CalleeTy> &U) {
129   OS << U.Range;
130   for (auto &Call : U.Calls)
131     OS << ", " << Call;
132   return OS;
133 }
134 
135 /// Calculate the allocation size of a given alloca. Returns empty range
136 // in case of confution.
137 ConstantRange getStaticAllocaSizeRange(const AllocaInst &AI) {
138   const DataLayout &DL = AI.getModule()->getDataLayout();
139   TypeSize TS = DL.getTypeAllocSize(AI.getAllocatedType());
140   unsigned PointerSize = DL.getMaxPointerSizeInBits();
141   // Fallback to empty range for alloca size.
142   ConstantRange R = ConstantRange::getEmpty(PointerSize);
143   if (TS.isScalable())
144     return R;
145   APInt APSize(PointerSize, TS.getFixedSize(), true);
146   if (APSize.isNonPositive())
147     return R;
148   if (AI.isArrayAllocation()) {
149     const auto *C = dyn_cast<ConstantInt>(AI.getArraySize());
150     if (!C)
151       return R;
152     bool Overflow = false;
153     APInt Mul = C->getValue();
154     if (Mul.isNonPositive())
155       return R;
156     Mul = Mul.sextOrTrunc(PointerSize);
157     APSize = APSize.smul_ov(Mul, Overflow);
158     if (Overflow)
159       return R;
160   }
161   R = ConstantRange(APInt::getNullValue(PointerSize), APSize);
162   assert(!isUnsafe(R));
163   return R;
164 }
165 
166 template <typename CalleeTy> struct FunctionInfo {
167   std::map<const AllocaInst *, UseInfo<CalleeTy>> Allocas;
168   std::map<uint32_t, UseInfo<CalleeTy>> Params;
169   // TODO: describe return value as depending on one or more of its arguments.
170 
171   // StackSafetyDataFlowAnalysis counter stored here for faster access.
172   int UpdateCount = 0;
173 
174   void print(raw_ostream &O, StringRef Name, const Function *F) const {
175     // TODO: Consider different printout format after
176     // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then.
177     O << "  @" << Name << ((F && F->isDSOLocal()) ? "" : " dso_preemptable")
178       << ((F && F->isInterposable()) ? " interposable" : "") << "\n";
179 
180     O << "    args uses:\n";
181     for (auto &KV : Params) {
182       O << "      ";
183       if (F)
184         O << F->getArg(KV.first)->getName();
185       else
186         O << formatv("arg{0}", KV.first);
187       O << "[]: " << KV.second << "\n";
188     }
189 
190     O << "    allocas uses:\n";
191     if (F) {
192       for (auto &I : instructions(F)) {
193         if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
194           auto &AS = Allocas.find(AI)->second;
195           O << "      " << AI->getName() << "["
196             << getStaticAllocaSizeRange(*AI).getUpper() << "]: " << AS << "\n";
197         }
198       }
199     } else {
200       assert(Allocas.empty());
201     }
202     O << "\n";
203   }
204 };
205 
// Stack safety results for each analyzed global value (function) of a module.
using GVToSSI = std::map<const GlobalValue *, FunctionInfo<GlobalValue>>;
207 
208 } // namespace
209 
// Result of the local (single-function) analysis; built lazily by
// StackSafetyInfo::getInfo().
struct StackSafetyInfo::InfoTy {
  FunctionInfo<GlobalValue> Info;
};

// Result of the module-wide analysis; built lazily by
// StackSafetyGlobalInfo::getInfo().
struct StackSafetyGlobalInfo::InfoTy {
  // Per-function results after interprocedural propagation.
  GVToSSI Info;
  // Allocas whose accessed range is contained in their static size range.
  SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
};
218 
219 namespace {
220 
/// Intraprocedural part of the analysis. For one function it computes the
/// byte-offset ranges accessed through every alloca and every non-byval
/// pointer argument, together with the calls those addresses are passed to.
class StackSafetyLocalAnalysis {
  Function &F;
  const DataLayout &DL;
  ScalarEvolution &SE;
  // Bit width of every range computed here (DataLayout::getPointerSizeInBits).
  // NOTE: must stay declared before UnknownRange, which is initialized from it.
  unsigned PointerSize = 0;

  // Full set of PointerSize bits: "may access any offset".
  const ConstantRange UnknownRange;

  // Signed range of (Addr - Base), or UnknownRange.
  ConstantRange offsetFrom(Value *Addr, Value *Base);
  // Range of bytes touched when accessing SizeRange bytes at Addr, relative to
  // Base.
  ConstantRange getAccessRange(Value *Addr, Value *Base,
                               const ConstantRange &SizeRange);
  // Same as above for a fixed (or scalable) access size.
  ConstantRange getAccessRange(Value *Addr, Value *Base, TypeSize Size);
  // Range of bytes relative to Base touched by memory intrinsic MI through U.
  ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U,
                                           Value *Base);

  // Walks all uses of Ptr and records accesses and calls into AS.
  bool analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS,
                      const StackLifetime &SL);

public:
  StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
      : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
        PointerSize(DL.getPointerSizeInBits()),
        UnknownRange(PointerSize, true) {}

  // Run the analysis on the associated function.
  FunctionInfo<GlobalValue> run();
};
248 
249 ConstantRange StackSafetyLocalAnalysis::offsetFrom(Value *Addr, Value *Base) {
250   if (!SE.isSCEVable(Addr->getType()) || !SE.isSCEVable(Base->getType()))
251     return UnknownRange;
252 
253   auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
254   const SCEV *AddrExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Addr), PtrTy);
255   const SCEV *BaseExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Base), PtrTy);
256   const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
257 
258   ConstantRange Offset = SE.getSignedRange(Diff);
259   if (isUnsafe(Offset))
260     return UnknownRange;
261   return Offset.sextOrTrunc(PointerSize);
262 }
263 
264 ConstantRange
265 StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
266                                          const ConstantRange &SizeRange) {
267   // Zero-size loads and stores do not access memory.
268   if (SizeRange.isEmptySet())
269     return ConstantRange::getEmpty(PointerSize);
270   assert(!isUnsafe(SizeRange));
271 
272   ConstantRange Offsets = offsetFrom(Addr, Base);
273   if (isUnsafe(Offsets))
274     return UnknownRange;
275 
276   Offsets = addOverflowNever(Offsets, SizeRange);
277   if (isUnsafe(Offsets))
278     return UnknownRange;
279   return Offsets;
280 }
281 
282 ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
283                                                        TypeSize Size) {
284   if (Size.isScalable())
285     return UnknownRange;
286   APInt APSize(PointerSize, Size.getFixedSize(), true);
287   if (APSize.isNegative())
288     return UnknownRange;
289   return getAccessRange(
290       Addr, Base, ConstantRange(APInt::getNullValue(PointerSize), APSize));
291 }
292 
293 ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
294     const MemIntrinsic *MI, const Use &U, Value *Base) {
295   if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
296     if (MTI->getRawSource() != U && MTI->getRawDest() != U)
297       return ConstantRange::getEmpty(PointerSize);
298   } else {
299     if (MI->getRawDest() != U)
300       return ConstantRange::getEmpty(PointerSize);
301   }
302 
303   auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
304   if (!SE.isSCEVable(MI->getLength()->getType()))
305     return UnknownRange;
306 
307   const SCEV *Expr =
308       SE.getTruncateOrZeroExtend(SE.getSCEV(MI->getLength()), CalculationTy);
309   ConstantRange Sizes = SE.getSignedRange(Expr);
310   if (Sizes.getUpper().isNegative() || isUnsafe(Sizes))
311     return UnknownRange;
312   Sizes = Sizes.sextOrTrunc(PointerSize);
313   ConstantRange SizeRange(APInt::getNullValue(PointerSize),
314                           Sizes.getUpper() - 1);
315   return getAccessRange(U, Base, SizeRange);
316 }
317 
/// The function analyzes all local uses of Ptr (alloca or argument) and
/// calculates local access range and all function calls where it was used.
///
/// Accessed offsets (relative to Ptr) are merged into US.Range, and calls that
/// receive the address as an argument are appended to US.Calls. Addresses
/// derived from Ptr (casts, GEPs, PHIs, ...) are followed transitively. Uses
/// that StackLifetime reports as unreachable are ignored.
///
/// Returns false, after widening US.Range to UnknownRange, as soon as a use
/// cannot be bounded. Such uses are:
///  - the address is stored, returned, or used as a non-argument call operand;
///  - the callee is not a GlobalValue;
///  - for an alloca, a load/store/call happens where SL says it is not alive.
/// Returns true otherwise.
bool StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
                                              UseInfo<GlobalValue> &US,
                                              const StackLifetime &SL) {
  SmallPtrSet<const Value *, 16> Visited;
  SmallVector<const Value *, 8> WorkList;
  WorkList.push_back(Ptr);
  // Null when Ptr is an argument; lifetime checks apply only to allocas.
  const AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);

  // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
  while (!WorkList.empty()) {
    const Value *V = WorkList.pop_back_val();
    for (const Use &UI : V->uses()) {
      // NOTE: assumes every user is an Instruction; cast<> asserts otherwise.
      const auto *I = cast<Instruction>(UI.getUser());
      if (!SL.isReachable(I))
        continue;

      assert(V == UI.get());

      switch (I->getOpcode()) {
      case Instruction::Load: {
        // Reading the alloca where it is not alive is treated as unsafe.
        if (AI && !SL.isAliveAfter(AI, I)) {
          US.updateRange(UnknownRange);
          return false;
        }
        US.updateRange(
            getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
        break;
      }

      case Instruction::VAArg:
        // "va-arg" from a pointer is safe.
        break;
      case Instruction::Store: {
        if (V == I->getOperand(0)) {
          // Stored the pointer - conservatively assume it may be unsafe.
          US.updateRange(UnknownRange);
          return false;
        }
        if (AI && !SL.isAliveAfter(AI, I)) {
          US.updateRange(UnknownRange);
          return false;
        }
        // Storing *to* the address: the accessed size is the store size of the
        // stored value's type.
        US.updateRange(getAccessRange(
            UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
        break;
      }

      case Instruction::Ret:
        // Information leak.
        // FIXME: Process parameters correctly. This is a leak only if we return
        // alloca.
        US.updateRange(UnknownRange);
        return false;

      case Instruction::Call:
      case Instruction::Invoke: {
        // Lifetime markers neither access nor leak the address.
        if (I->isLifetimeStartOrEnd())
          break;

        if (AI && !SL.isAliveAfter(AI, I)) {
          US.updateRange(UnknownRange);
          return false;
        }

        if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
          US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr));
          break;
        }

        // The address is used as something other than an argument (e.g. the
        // called operand itself): give up.
        const auto &CB = cast<CallBase>(*I);
        if (!CB.isArgOperand(&UI)) {
          US.updateRange(UnknownRange);
          return false;
        }

        // A byval argument is accessed (copied) by the call itself.
        unsigned ArgNo = CB.getArgOperandNo(&UI);
        if (CB.isByValArgument(ArgNo)) {
          US.updateRange(getAccessRange(
              UI, Ptr, DL.getTypeStoreSize(CB.getParamByValType(ArgNo))));
          break;
        }

        // FIXME: consult devirt?
        // Do not follow aliases, otherwise we could inadvertently follow
        // dso_preemptable aliases or aliases with interposable linkage.
        const GlobalValue *Callee =
            dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
        if (!Callee) {
          US.updateRange(UnknownRange);
          return false;
        }

        assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee));
        // Recorded for the interprocedural phase. offsetFrom never returns
        // the empty set, preserving CallInfo's non-empty Offset requirement.
        US.Calls.emplace_back(Callee, ArgNo, offsetFrom(UI, Ptr));
        break;
      }

      default:
        // Any other user is treated as a value derived from the address;
        // visit its uses too (each instruction at most once).
        if (Visited.insert(I).second)
          WorkList.push_back(cast<const Instruction>(I));
      }
    }
  }

  return true;
}
426 
427 FunctionInfo<GlobalValue> StackSafetyLocalAnalysis::run() {
428   FunctionInfo<GlobalValue> Info;
429   assert(!F.isDeclaration() &&
430          "Can't run StackSafety on a function declaration");
431 
432   LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n");
433 
434   SmallVector<AllocaInst *, 64> Allocas;
435   for (auto &I : instructions(F))
436     if (auto *AI = dyn_cast<AllocaInst>(&I))
437       Allocas.push_back(AI);
438   StackLifetime SL(F, Allocas, StackLifetime::LivenessType::Must);
439   SL.run();
440 
441   for (auto *AI : Allocas) {
442     auto &UI = Info.Allocas.emplace(AI, PointerSize).first->second;
443     analyzeAllUses(AI, UI, SL);
444   }
445 
446   for (Argument &A : make_range(F.arg_begin(), F.arg_end())) {
447     // Non pointers and bypass arguments are not going to be used in any global
448     // processing.
449     if (A.getType()->isPointerTy() && !A.hasByValAttr()) {
450       auto &UI = Info.Params.emplace(A.getArgNo(), PointerSize).first->second;
451       analyzeAllUses(&A, UI, SL);
452     }
453   }
454 
455   LLVM_DEBUG(Info.print(dbgs(), F.getName(), &F));
456   LLVM_DEBUG(dbgs() << "[StackSafety] done\n");
457   return Info;
458 }
459 
/// Interprocedural fixed-point propagation over parameter access ranges.
/// Each parameter's range is widened by the ranges of the callee parameters
/// it is passed to, until nothing changes. A function updated more than
/// StackSafetyMaxIterations times has its changing ranges widened straight to
/// the full set.
template <typename CalleeTy> class StackSafetyDataFlowAnalysis {
  using FunctionMap = std::map<const CalleeTy *, FunctionInfo<CalleeTy>>;

  // Functions taking part in the propagation; updated in place.
  FunctionMap Functions;
  // Full set of the analysis' bit width: "may access anything".
  const ConstantRange UnknownRange;

  // Callee-to-Caller multimap.
  DenseMap<const CalleeTy *, SmallVector<const CalleeTy *, 4>> Callers;
  // Functions whose parameter ranges must be recomputed.
  SetVector<const CalleeTy *> WorkList;

  bool updateOneUse(UseInfo<CalleeTy> &US, bool UpdateToFullSet);
  void updateOneNode(const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS);
  void updateOneNode(const CalleeTy *Callee) {
    updateOneNode(Callee, Functions.find(Callee)->second);
  }
  void updateAllNodes() {
    for (auto &F : Functions)
      updateOneNode(F.first, F.second);
  }
  void runDataFlow();
#ifndef NDEBUG
  void verifyFixedPoint();
#endif

public:
  StackSafetyDataFlowAnalysis(uint32_t PointerBitWidth, FunctionMap Functions)
      : Functions(std::move(Functions)),
        UnknownRange(ConstantRange::getFull(PointerBitWidth)) {}

  // Runs the propagation and returns the updated functions.
  const FunctionMap &run();

  // Range accessed through parameter ParamNo of Callee when the address is
  // passed at Offsets; UnknownRange for unknown callees or parameters.
  ConstantRange getArgumentAccessRange(const CalleeTy *Callee, unsigned ParamNo,
                                       const ConstantRange &Offsets) const;
};
494 
495 template <typename CalleeTy>
496 ConstantRange StackSafetyDataFlowAnalysis<CalleeTy>::getArgumentAccessRange(
497     const CalleeTy *Callee, unsigned ParamNo,
498     const ConstantRange &Offsets) const {
499   auto FnIt = Functions.find(Callee);
500   // Unknown callee (outside of LTO domain or an indirect call).
501   if (FnIt == Functions.end())
502     return UnknownRange;
503   auto &FS = FnIt->second;
504   auto ParamIt = FS.Params.find(ParamNo);
505   if (ParamIt == FS.Params.end())
506     return UnknownRange;
507   auto &Access = ParamIt->second.Range;
508   if (Access.isEmptySet())
509     return Access;
510   if (Access.isFullSet())
511     return UnknownRange;
512   return addOverflowNever(Access, Offsets);
513 }
514 
515 template <typename CalleeTy>
516 bool StackSafetyDataFlowAnalysis<CalleeTy>::updateOneUse(UseInfo<CalleeTy> &US,
517                                                          bool UpdateToFullSet) {
518   bool Changed = false;
519   for (auto &CS : US.Calls) {
520     assert(!CS.Offset.isEmptySet() &&
521            "Param range can't be empty-set, invalid offset range");
522 
523     ConstantRange CalleeRange =
524         getArgumentAccessRange(CS.Callee, CS.ParamNo, CS.Offset);
525     if (!US.Range.contains(CalleeRange)) {
526       Changed = true;
527       if (UpdateToFullSet)
528         US.Range = UnknownRange;
529       else
530         US.updateRange(CalleeRange);
531     }
532   }
533   return Changed;
534 }
535 
536 template <typename CalleeTy>
537 void StackSafetyDataFlowAnalysis<CalleeTy>::updateOneNode(
538     const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS) {
539   bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations;
540   bool Changed = false;
541   for (auto &KV : FS.Params)
542     Changed |= updateOneUse(KV.second, UpdateToFullSet);
543 
544   if (Changed) {
545     LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount
546                       << (UpdateToFullSet ? ", full-set" : "") << "] " << &FS
547                       << "\n");
548     // Callers of this function may need updating.
549     for (auto &CallerID : Callers[Callee])
550       WorkList.insert(CallerID);
551 
552     ++FS.UpdateCount;
553   }
554 }
555 
556 template <typename CalleeTy>
557 void StackSafetyDataFlowAnalysis<CalleeTy>::runDataFlow() {
558   SmallVector<const CalleeTy *, 16> Callees;
559   for (auto &F : Functions) {
560     Callees.clear();
561     auto &FS = F.second;
562     for (auto &KV : FS.Params)
563       for (auto &CS : KV.second.Calls)
564         Callees.push_back(CS.Callee);
565 
566     llvm::sort(Callees);
567     Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end());
568 
569     for (auto &Callee : Callees)
570       Callers[Callee].push_back(F.first);
571   }
572 
573   updateAllNodes();
574 
575   while (!WorkList.empty()) {
576     const CalleeTy *Callee = WorkList.back();
577     WorkList.pop_back();
578     updateOneNode(Callee);
579   }
580 }
581 
582 #ifndef NDEBUG
583 template <typename CalleeTy>
584 void StackSafetyDataFlowAnalysis<CalleeTy>::verifyFixedPoint() {
585   WorkList.clear();
586   updateAllNodes();
587   assert(WorkList.empty());
588 }
589 #endif
590 
591 template <typename CalleeTy>
592 const typename StackSafetyDataFlowAnalysis<CalleeTy>::FunctionMap &
593 StackSafetyDataFlowAnalysis<CalleeTy>::run() {
594   runDataFlow();
595   LLVM_DEBUG(verifyFixedPoint());
596   return Functions;
597 }
598 
599 FunctionSummary *resolveCallee(GlobalValueSummary *S) {
600   while (S) {
601     if (!S->isLive() || !S->isDSOLocal())
602       return nullptr;
603     if (FunctionSummary *FS = dyn_cast<FunctionSummary>(S))
604       return FS;
605     AliasSummary *AS = dyn_cast<AliasSummary>(S);
606     if (!AS || !AS->hasAliasee())
607       return nullptr;
608     S = AS->getBaseObject();
609     if (S == AS)
610       return nullptr;
611   }
612   return nullptr;
613 }
614 
615 const Function *findCalleeInModule(const GlobalValue *GV) {
616   while (GV) {
617     if (GV->isDeclaration() || GV->isInterposable() || !GV->isDSOLocal())
618       return nullptr;
619     if (const Function *F = dyn_cast<Function>(GV))
620       return F;
621     const GlobalAlias *A = dyn_cast<GlobalAlias>(GV);
622     if (!A)
623       return nullptr;
624     GV = A->getBaseObject();
625     if (GV == A)
626       return nullptr;
627   }
628   return nullptr;
629 }
630 
631 GlobalValueSummary *getGlobalValueSummary(const ModuleSummaryIndex *Index,
632                                           uint64_t ValueGUID) {
633   auto VI = Index->getValueInfo(ValueGUID);
634   if (!VI || VI.getSummaryList().empty())
635     return nullptr;
636   auto &Summary = VI.getSummaryList()[0];
637   return Summary.get();
638 }
639 
640 const ConstantRange *findParamAccess(const FunctionSummary &FS,
641                                      uint32_t ParamNo) {
642   assert(FS.isLive());
643   assert(FS.isDSOLocal());
644   for (auto &PS : FS.paramAccesses())
645     if (ParamNo == PS.ParamNo)
646       return &PS.Use;
647   return nullptr;
648 }
649 
650 void resolveAllCalls(UseInfo<GlobalValue> &Use,
651                      const ModuleSummaryIndex *Index) {
652   ConstantRange FullSet(Use.Range.getBitWidth(), true);
653   for (auto &C : Use.Calls) {
654     const Function *F = findCalleeInModule(C.Callee);
655     if (F) {
656       C.Callee = F;
657       continue;
658     }
659 
660     if (!Index)
661       return Use.updateRange(FullSet);
662     GlobalValueSummary *GVS = getGlobalValueSummary(Index, C.Callee->getGUID());
663 
664     FunctionSummary *FS = resolveCallee(GVS);
665     ++NumModuleCalleeLookupTotal;
666     if (!FS) {
667       ++NumModuleCalleeLookupFailed;
668       return Use.updateRange(FullSet);
669     }
670     const ConstantRange *Found = findParamAccess(*FS, C.ParamNo);
671     if (!Found || Found->isFullSet())
672       return Use.updateRange(FullSet);
673     ConstantRange Access = Found->sextOrTrunc(Use.Range.getBitWidth());
674     if (!Access.isEmptySet())
675       Use.updateRange(addOverflowNever(Access, C.Offset));
676     C.Callee = nullptr;
677   }
678 
679   Use.Calls.erase(std::remove_if(Use.Calls.begin(), Use.Calls.end(),
680                                  [](auto &T) { return !T.Callee; }),
681                   Use.Calls.end());
682 }
683 
684 GVToSSI createGlobalStackSafetyInfo(
685     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions,
686     const ModuleSummaryIndex *Index) {
687   GVToSSI SSI;
688   if (Functions.empty())
689     return SSI;
690 
691   // FIXME: Simplify printing and remove copying here.
692   auto Copy = Functions;
693 
694   for (auto &FnKV : Copy)
695     for (auto &KV : FnKV.second.Params)
696       resolveAllCalls(KV.second, Index);
697 
698   uint32_t PointerSize = Copy.begin()
699                              ->first->getParent()
700                              ->getDataLayout()
701                              .getMaxPointerSizeInBits();
702   StackSafetyDataFlowAnalysis<GlobalValue> SSDFA(PointerSize, std::move(Copy));
703 
704   for (auto &F : SSDFA.run()) {
705     auto FI = F.second;
706     auto &SrcF = Functions[F.first];
707     for (auto &KV : FI.Allocas) {
708       auto &A = KV.second;
709       resolveAllCalls(A, Index);
710       for (auto &C : A.Calls) {
711         A.updateRange(
712             SSDFA.getArgumentAccessRange(C.Callee, C.ParamNo, C.Offset));
713       }
714       // FIXME: This is needed only to preserve calls in print() results.
715       A.Calls = SrcF.Allocas.find(KV.first)->second.Calls;
716     }
717     for (auto &KV : FI.Params) {
718       auto &P = KV.second;
719       P.Calls = SrcF.Params.find(KV.first)->second.Calls;
720     }
721     SSI[F.first] = std::move(FI);
722   }
723 
724   return SSI;
725 }
726 
727 } // end anonymous namespace
728 
StackSafetyInfo::StackSafetyInfo() = default;

// F is the function to analyze. GetSE supplies its ScalarEvolution and is
// invoked only when the result is first requested (see getInfo()).
StackSafetyInfo::StackSafetyInfo(Function *F,
                                 std::function<ScalarEvolution &()> GetSE)
    : F(F), GetSE(GetSE) {}

StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default;

StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default;

StackSafetyInfo::~StackSafetyInfo() = default;
740 
741 const StackSafetyInfo::InfoTy &StackSafetyInfo::getInfo() const {
742   if (!Info) {
743     StackSafetyLocalAnalysis SSLA(*F, GetSE());
744     Info.reset(new InfoTy{SSLA.run()});
745   }
746   return *Info;
747 }
748 
749 void StackSafetyInfo::print(raw_ostream &O) const {
750   getInfo().Info.print(O, F->getName(), dyn_cast<Function>(F));
751 }
752 
753 const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
754   if (!Info) {
755     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions;
756     for (auto &F : M->functions()) {
757       if (!F.isDeclaration()) {
758         auto FI = GetSSI(F).getInfo().Info;
759         Functions.emplace(&F, std::move(FI));
760       }
761     }
762     Info.reset(new InfoTy{
763         createGlobalStackSafetyInfo(std::move(Functions), Index), {}});
764     for (auto &FnKV : Info->Info) {
765       for (auto &KV : FnKV.second.Allocas) {
766         ++NumAllocaTotal;
767         const AllocaInst *AI = KV.first;
768         if (getStaticAllocaSizeRange(*AI).contains(KV.second.Range)) {
769           Info->SafeAllocas.insert(AI);
770           ++NumAllocaStackSafe;
771         }
772       }
773     }
774     if (StackSafetyPrint)
775       print(errs());
776   }
777   return *Info;
778 }
779 
780 std::vector<FunctionSummary::ParamAccess>
781 StackSafetyInfo::getParamAccesses() const {
782   // Implementation transforms internal representation of parameter information
783   // into FunctionSummary format.
784   std::vector<FunctionSummary::ParamAccess> ParamAccesses;
785   for (const auto &KV : getInfo().Info.Params) {
786     auto &PS = KV.second;
787     // Parameter accessed by any or unknown offset, represented as FullSet by
788     // StackSafety, is handled as the parameter for which we have no
789     // StackSafety info at all. So drop it to reduce summary size.
790     if (PS.Range.isFullSet())
791       continue;
792 
793     ParamAccesses.emplace_back(KV.first, PS.Range);
794     FunctionSummary::ParamAccess &Param = ParamAccesses.back();
795 
796     Param.Calls.reserve(PS.Calls.size());
797     for (auto &C : PS.Calls) {
798       // Parameter forwarded into another function by any or unknown offset
799       // will make ParamAccess::Range as FullSet anyway. So we can drop the
800       // entire parameter like we did above.
801       // TODO(vitalybuka): Return already filtered parameters from getInfo().
802       if (C.Offset.isFullSet()) {
803         ParamAccesses.pop_back();
804         break;
805       }
806       Param.Calls.emplace_back(C.ParamNo, C.Callee->getGUID(), C.Offset);
807     }
808   }
809   return ParamAccesses;
810 }
811 
StackSafetyGlobalInfo::StackSafetyGlobalInfo() = default;

// M is the module to analyze, and GetSSI returns the local result for one of
// its functions. Index may be null; when present, it is used to resolve
// callees that cannot be resolved within M. With -stack-safety-run the result
// is computed immediately instead of on first query.
StackSafetyGlobalInfo::StackSafetyGlobalInfo(
    Module *M, std::function<const StackSafetyInfo &(Function &F)> GetSSI,
    const ModuleSummaryIndex *Index)
    : M(M), GetSSI(GetSSI), Index(Index) {
  if (StackSafetyRun)
    getInfo();
}

StackSafetyGlobalInfo::StackSafetyGlobalInfo(StackSafetyGlobalInfo &&) =
    default;

StackSafetyGlobalInfo &
StackSafetyGlobalInfo::operator=(StackSafetyGlobalInfo &&) = default;

StackSafetyGlobalInfo::~StackSafetyGlobalInfo() = default;
829 
830 bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
831   const auto &Info = getInfo();
832   return Info.SafeAllocas.count(&AI);
833 }
834 
835 void StackSafetyGlobalInfo::print(raw_ostream &O) const {
836   auto &SSI = getInfo().Info;
837   if (SSI.empty())
838     return;
839   const Module &M = *SSI.begin()->first->getParent();
840   for (auto &F : M.functions()) {
841     if (!F.isDeclaration()) {
842       SSI.find(&F)->second.print(O, F.getName(), &F);
843       O << "\n";
844     }
845   }
846 }
847 
// Debugging helper: prints the module-wide result to dbgs().
LLVM_DUMP_METHOD void StackSafetyGlobalInfo::dump() const { print(dbgs()); }
849 
850 AnalysisKey StackSafetyAnalysis::Key;
851 
852 StackSafetyInfo StackSafetyAnalysis::run(Function &F,
853                                          FunctionAnalysisManager &AM) {
854   return StackSafetyInfo(&F, [&AM, &F]() -> ScalarEvolution & {
855     return AM.getResult<ScalarEvolutionAnalysis>(F);
856   });
857 }
858 
859 PreservedAnalyses StackSafetyPrinterPass::run(Function &F,
860                                               FunctionAnalysisManager &AM) {
861   OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n";
862   AM.getResult<StackSafetyAnalysis>(F).print(OS);
863   return PreservedAnalyses::all();
864 }
865 
866 char StackSafetyInfoWrapperPass::ID = 0;
867 
868 StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) {
869   initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry());
870 }
871 
872 void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
873   AU.addRequiredTransitive<ScalarEvolutionWrapperPass>();
874   AU.setPreservesAll();
875 }
876 
877 void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const {
878   SSI.print(O);
879 }
880 
881 bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) {
882   auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
883   SSI = {&F, [SE]() -> ScalarEvolution & { return *SE; }};
884   return false;
885 }
886 
887 AnalysisKey StackSafetyGlobalAnalysis::Key;
888 
889 StackSafetyGlobalInfo
890 StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
891   // FIXME: Lookup Module Summary.
892   FunctionAnalysisManager &FAM =
893       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
894   return {&M,
895           [&FAM](Function &F) -> const StackSafetyInfo & {
896             return FAM.getResult<StackSafetyAnalysis>(F);
897           },
898           nullptr};
899 }
900 
901 PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M,
902                                                     ModuleAnalysisManager &AM) {
903   OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n";
904   AM.getResult<StackSafetyGlobalAnalysis>(M).print(OS);
905   return PreservedAnalyses::all();
906 }
907 
908 char StackSafetyGlobalInfoWrapperPass::ID = 0;
909 
910 StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass()
911     : ModulePass(ID) {
912   initializeStackSafetyGlobalInfoWrapperPassPass(
913       *PassRegistry::getPassRegistry());
914 }
915 
916 StackSafetyGlobalInfoWrapperPass::~StackSafetyGlobalInfoWrapperPass() = default;
917 
918 void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O,
919                                              const Module *M) const {
920   SSGI.print(O);
921 }
922 
923 void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage(
924     AnalysisUsage &AU) const {
925   AU.setPreservesAll();
926   AU.addRequired<StackSafetyInfoWrapperPass>();
927 }
928 
929 bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) {
930   const ModuleSummaryIndex *ImportSummary = nullptr;
931   if (auto *IndexWrapperPass =
932           getAnalysisIfAvailable<ImmutableModuleSummaryIndexWrapperPass>())
933     ImportSummary = IndexWrapperPass->getIndex();
934 
935   SSGI = {&M,
936           [this](Function &F) -> const StackSafetyInfo & {
937             return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult();
938           },
939           ImportSummary};
940   return false;
941 }
942 
943 bool llvm::needsParamAccessSummary(const Module &M) {
944   if (StackSafetyRun)
945     return true;
946   for (auto &F : M.functions())
947     if (F.hasFnAttribute(Attribute::SanitizeMemTag))
948       return true;
949   return false;
950 }
951 
/// Runs the inter-procedural stack safety data flow over the parameter
/// accesses recorded in the combined summary \p Index, and writes the results
/// back into the function summaries.
///
/// Every function summary that has parameter accesses is cleared. Only live,
/// DSO-local ones take part in the data flow and get their (range-only)
/// accesses re-populated from its result. Does nothing when the index carries
/// no parameter-access information.
void llvm::generateParamAccessSummary(ModuleSummaryIndex &Index) {
  if (!Index.hasParamAccess())
    return;
  // Conservative "may access anything" range, used when a callee cannot be
  // resolved.
  const ConstantRange FullSet(FunctionSummary::ParamAccess::RangeWidth, true);

  // Adds the number of param-access entries across all function summaries to
  // Stat; skipped entirely when statistics are disabled.
  auto CountParamAccesses = [&](auto &Stat) {
    if (!AreStatisticsEnabled())
      return;
    for (auto &GVS : Index)
      for (auto &GV : GVS.second.SummaryList)
        if (FunctionSummary *FS = dyn_cast<FunctionSummary>(GV.get()))
          Stat += FS->paramAccesses().size();
  };

  CountParamAccesses(NumCombinedParamAccessesBefore);

  std::map<const FunctionSummary *, FunctionInfo<FunctionSummary>> Functions;

  // Convert the ModuleSummaryIndex to a FunctionMap
  for (auto &GVS : Index) {
    for (auto &GV : GVS.second.SummaryList) {
      FunctionSummary *FS = dyn_cast<FunctionSummary>(GV.get());
      if (!FS || FS->paramAccesses().empty())
        continue;
      if (FS->isLive() && FS->isDSOLocal()) {
        FunctionInfo<FunctionSummary> FI;
        for (auto &PS : FS->paramAccesses()) {
          auto &US =
              FI.Params
                  .emplace(PS.ParamNo, FunctionSummary::ParamAccess::RangeWidth)
                  .first->second;
          US.Range = PS.Use;
          for (auto &Call : PS.Calls) {
            assert(!Call.Offsets.isFullSet());
            FunctionSummary *S = resolveCallee(
                Index.findSummaryInModule(Call.Callee, FS->modulePath()));
            ++NumCombinedCalleeLookupTotal;
            if (!S) {
              // The callee's summary is unavailable, so nothing can be
              // propagated through this call. Give up on the whole parameter:
              // mark it as accessing the full range and drop its call edges.
              ++NumCombinedCalleeLookupFailed;
              US.Range = FullSet;
              US.Calls.clear();
              break;
            }
            US.Calls.emplace_back(S, Call.ParamNo, Call.Offsets);
          }
        }
        Functions.emplace(FS, std::move(FI));
      }
      // Reset data for all summaries. Alive and DSO local will be set back from
      // the data flow results below. Anything else will not be accessed
      // by ThinLTO backend, so we can save on bitcode size.
      FS->setParamAccesses({});
    }
  }
  NumCombinedDataFlowNodes += Functions.size();
  StackSafetyDataFlowAnalysis<FunctionSummary> SSDFA(
      FunctionSummary::ParamAccess::RangeWidth, std::move(Functions));
  for (auto &KV : SSDFA.run()) {
    std::vector<FunctionSummary::ParamAccess> NewParams;
    NewParams.reserve(KV.second.Params.size());
    for (auto &Param : KV.second.Params) {
      // A full-set range carries no information: it is processed the same as
      // a missing entry, so it is not stored.
      if (Param.second.Range.isFullSet())
        continue;
      NewParams.emplace_back();
      FunctionSummary::ParamAccess &New = NewParams.back();
      New.ParamNo = Param.first;
      New.Use = Param.second.Range; // Only range is needed.
    }
    // NOTE(review): KV.first is presumably one of the (originally non-const)
    // FunctionSummary pointers inserted into Functions above, which is why
    // casting away const here is expected to be sound — confirm against
    // StackSafetyDataFlowAnalysis::run().
    const_cast<FunctionSummary *>(KV.first)->setParamAccesses(
        std::move(NewParams));
  }

  CountParamAccesses(NumCombinedParamAccessesAfter);
}
1027 
// Legacy pass manager registration for the local (per-function) analysis,
// which depends on ScalarEvolution. The trailing `false, true` arguments are
// presumably INITIALIZE_PASS's (CFG-only, is-analysis) flags — see
// llvm/PassSupport.h.
static const char LocalPassArg[] = "stack-safety-local";
static const char LocalPassName[] = "Stack Safety Local Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                      false, true)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                    false, true)

// Registration for the module-level analysis, under DEBUG_TYPE
// ("stack-safety"). It depends on the local analysis and on the optional
// immutable module summary index.
static const char GlobalPassName[] = "Stack Safety Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                      GlobalPassName, false, true)
INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ImmutableModuleSummaryIndexWrapperPass)
INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                    GlobalPassName, false, true)
1043