1 //===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "llvm/Analysis/StackSafetyAnalysis.h"
12 #include "llvm/ADT/APInt.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/Statistic.h"
15 #include "llvm/Analysis/ModuleSummaryAnalysis.h"
16 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
17 #include "llvm/IR/ConstantRange.h"
18 #include "llvm/IR/DerivedTypes.h"
19 #include "llvm/IR/GlobalValue.h"
20 #include "llvm/IR/InstIterator.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/Support/Casting.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <algorithm>
29 #include <memory>
30 
31 using namespace llvm;
32 
33 #define DEBUG_TYPE "stack-safety"
34 
35 STATISTIC(NumAllocaStackSafe, "Number of safe allocas");
36 STATISTIC(NumAllocaTotal, "Number of total allocas");
37 
38 static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations",
39                                              cl::init(20), cl::Hidden);
40 
41 static cl::opt<bool> StackSafetyPrint("stack-safety-print", cl::init(false),
42                                       cl::Hidden);
43 
44 static cl::opt<bool> StackSafetyRun("stack-safety-run", cl::init(false),
45                                     cl::Hidden);
46 
47 namespace {
48 
49 /// Describes use of address in as a function call argument.
50 template <typename CalleeTy> struct CallInfo {
51   /// Function being called.
52   const CalleeTy *Callee = nullptr;
53   /// Index of argument which pass address.
54   size_t ParamNo = 0;
55   // Offset range of address from base address (alloca or calling function
56   // argument).
57   // Range should never set to empty-set, that is an invalid access range
58   // that can cause empty-set to be propagated with ConstantRange::add
59   ConstantRange Offset;
60   CallInfo(const CalleeTy *Callee, size_t ParamNo, ConstantRange Offset)
61       : Callee(Callee), ParamNo(ParamNo), Offset(Offset) {}
62 };
63 
64 template <typename CalleeTy>
65 raw_ostream &operator<<(raw_ostream &OS, const CallInfo<CalleeTy> &P) {
66   return OS << "@" << P.Callee->getName() << "(arg" << P.ParamNo << ", "
67             << P.Offset << ")";
68 }
69 
70 /// Describe uses of address (alloca or parameter) inside of the function.
71 template <typename CalleeTy> struct UseInfo {
72   // Access range if the address (alloca or parameters).
73   // It is allowed to be empty-set when there are no known accesses.
74   ConstantRange Range;
75 
76   // List of calls which pass address as an argument.
77   SmallVector<CallInfo<CalleeTy>, 4> Calls;
78 
79   UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
80 
81   void updateRange(const ConstantRange &R) {
82     assert(!R.isUpperSignWrapped());
83     Range = Range.unionWith(R);
84     assert(!Range.isUpperSignWrapped());
85   }
86 };
87 
88 template <typename CalleeTy>
89 raw_ostream &operator<<(raw_ostream &OS, const UseInfo<CalleeTy> &U) {
90   OS << U.Range;
91   for (auto &Call : U.Calls)
92     OS << ", " << Call;
93   return OS;
94 }
95 
96 // Check if we should bailout for such ranges.
97 bool isUnsafe(const ConstantRange &R) {
98   return R.isEmptySet() || R.isFullSet() || R.isUpperSignWrapped();
99 }
100 
101 ConstantRange addOverflowNever(const ConstantRange &L, const ConstantRange &R) {
102   if (L.signedAddMayOverflow(R) !=
103       ConstantRange::OverflowResult::NeverOverflows)
104     return ConstantRange(L.getBitWidth(), true);
105   return L.add(R);
106 }
107 
108 /// Calculate the allocation size of a given alloca. Returns empty range
109 // in case of confution.
110 ConstantRange getStaticAllocaSizeRange(const AllocaInst &AI) {
111   const DataLayout &DL = AI.getModule()->getDataLayout();
112   TypeSize TS = DL.getTypeAllocSize(AI.getAllocatedType());
113   unsigned PointerSize = DL.getMaxPointerSizeInBits();
114   // Fallback to empty range for alloca size.
115   ConstantRange R = ConstantRange::getEmpty(PointerSize);
116   if (TS.isScalable())
117     return R;
118   APInt APSize(PointerSize, TS.getFixedSize(), true);
119   if (APSize.isNonPositive())
120     return R;
121   if (AI.isArrayAllocation()) {
122     const auto *C = dyn_cast<ConstantInt>(AI.getArraySize());
123     if (!C)
124       return R;
125     bool Overflow = false;
126     APInt Mul = C->getValue();
127     if (Mul.isNonPositive())
128       return R;
129     Mul = Mul.sextOrTrunc(PointerSize);
130     APSize = APSize.smul_ov(Mul, Overflow);
131     if (Overflow)
132       return R;
133   }
134   R = ConstantRange(APInt::getNullValue(PointerSize), APSize);
135   assert(!isUnsafe(R));
136   return R;
137 }
138 
139 template <typename CalleeTy> struct FunctionInfo {
140   std::map<const AllocaInst *, UseInfo<CalleeTy>> Allocas;
141   std::map<uint32_t, UseInfo<CalleeTy>> Params;
142   // TODO: describe return value as depending on one or more of its arguments.
143 
144   // StackSafetyDataFlowAnalysis counter stored here for faster access.
145   int UpdateCount = 0;
146 
147   void print(raw_ostream &O, StringRef Name, const Function *F) const {
148     // TODO: Consider different printout format after
149     // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then.
150     O << "  @" << Name << ((F && F->isDSOLocal()) ? "" : " dso_preemptable")
151       << ((F && F->isInterposable()) ? " interposable" : "") << "\n";
152 
153     O << "    args uses:\n";
154     for (auto &KV : Params) {
155       O << "      ";
156       if (F)
157         O << F->getArg(KV.first)->getName();
158       else
159         O << formatv("arg{0}", KV.first);
160       O << "[]: " << KV.second << "\n";
161     }
162 
163     O << "    allocas uses:\n";
164     if (F) {
165       for (auto &I : instructions(F)) {
166         if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
167           auto &AS = Allocas.find(AI)->second;
168           O << "      " << AI->getName() << "["
169             << getStaticAllocaSizeRange(*AI).getUpper() << "]: " << AS << "\n";
170         }
171       }
172     } else {
173       assert(Allocas.empty());
174     }
175   }
176 };
177 
/// Maps each analyzed global value to its stack safety summary.
using GVToSSI = std::map<const GlobalValue *, FunctionInfo<GlobalValue>>;
179 
180 } // namespace
181 
/// Storage behind StackSafetyInfo: the result of the local (intra-procedural)
/// analysis of one function.
struct StackSafetyInfo::InfoTy {
  FunctionInfo<GlobalValue> Info;
};

/// Storage behind StackSafetyGlobalInfo: the module-wide results plus the set
/// of allocas proven safe.
struct StackSafetyGlobalInfo::InfoTy {
  GVToSSI Info;
  // Allocas whose accessed range lies within the allocation.
  SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
};
190 
191 namespace {
192 
/// Intra-procedural part of the analysis. For one function it computes the
/// offset ranges accessed through each alloca and each non-byval pointer
/// argument, and the calls those addresses are passed to.
class StackSafetyLocalAnalysis {
  Function &F;
  const DataLayout &DL;
  ScalarEvolution &SE;
  // Must stay declared before UnknownRange, which is initialized from it.
  unsigned PointerSize = 0;

  // Full-set range used whenever an access cannot be bounded.
  const ConstantRange UnknownRange;

  // Signed range of Addr - Base, or UnknownRange.
  ConstantRange offsetFrom(Value *Addr, Value *Base);
  // Bytes accessed, relative to Base, by an access of SizeRange bytes at Addr.
  ConstantRange getAccessRange(Value *Addr, Value *Base,
                               ConstantRange SizeRange);
  // Same, for an access of a fixed (non-scalable) type size.
  ConstantRange getAccessRange(Value *Addr, Value *Base, TypeSize Size);
  // Bytes accessed through use U of a memory intrinsic (memset/memcpy/...).
  ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U,
                                           Value *Base);

  // Collects the accesses and calls of all uses of Ptr into AS; returns false
  // if the analysis gave up on Ptr.
  bool analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS);

public:
  StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
      : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
        PointerSize(DL.getPointerSizeInBits()),
        UnknownRange(PointerSize, true) {}

  // Run the analysis on the associated function and return its summary.
  FunctionInfo<GlobalValue> run();
};
219 
220 ConstantRange StackSafetyLocalAnalysis::offsetFrom(Value *Addr, Value *Base) {
221   if (!SE.isSCEVable(Addr->getType()) || !SE.isSCEVable(Base->getType()))
222     return UnknownRange;
223 
224   auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
225   const SCEV *AddrExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Addr), PtrTy);
226   const SCEV *BaseExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Base), PtrTy);
227   const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
228 
229   ConstantRange Offset = SE.getSignedRange(Diff);
230   if (isUnsafe(Offset))
231     return UnknownRange;
232   return Offset.sextOrTrunc(PointerSize);
233 }
234 
235 ConstantRange
236 StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
237                                          ConstantRange SizeRange) {
238   // Zero-size loads and stores do not access memory.
239   if (SizeRange.isEmptySet())
240     return ConstantRange::getEmpty(PointerSize);
241   assert(!isUnsafe(SizeRange));
242 
243   ConstantRange Offsets = offsetFrom(Addr, Base);
244   if (isUnsafe(Offsets))
245     return UnknownRange;
246 
247   Offsets = addOverflowNever(Offsets, SizeRange);
248   if (isUnsafe(Offsets))
249     return UnknownRange;
250   return Offsets;
251 }
252 
253 ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
254                                                        TypeSize Size) {
255   if (Size.isScalable())
256     return UnknownRange;
257   APInt APSize(PointerSize, Size.getFixedSize(), true);
258   if (APSize.isNegative())
259     return UnknownRange;
260   return getAccessRange(
261       Addr, Base, ConstantRange(APInt::getNullValue(PointerSize), APSize));
262 }
263 
264 ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
265     const MemIntrinsic *MI, const Use &U, Value *Base) {
266   if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
267     if (MTI->getRawSource() != U && MTI->getRawDest() != U)
268       return ConstantRange::getEmpty(PointerSize);
269   } else {
270     if (MI->getRawDest() != U)
271       return ConstantRange::getEmpty(PointerSize);
272   }
273 
274   auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
275   if (!SE.isSCEVable(MI->getLength()->getType()))
276     return UnknownRange;
277 
278   const SCEV *Expr =
279       SE.getTruncateOrZeroExtend(SE.getSCEV(MI->getLength()), CalculationTy);
280   ConstantRange Sizes = SE.getSignedRange(Expr);
281   if (Sizes.getUpper().isNegative() || isUnsafe(Sizes))
282     return UnknownRange;
283   Sizes = Sizes.sextOrTrunc(PointerSize);
284   ConstantRange SizeRange(APInt::getNullValue(PointerSize),
285                           Sizes.getUpper() - 1);
286   return getAccessRange(U, Base, SizeRange);
287 }
288 
289 /// The function analyzes all local uses of Ptr (alloca or argument) and
290 /// calculates local access range and all function calls where it was used.
291 bool StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
292                                               UseInfo<GlobalValue> &US) {
293   SmallPtrSet<const Value *, 16> Visited;
294   SmallVector<const Value *, 8> WorkList;
295   WorkList.push_back(Ptr);
296 
297   // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
298   while (!WorkList.empty()) {
299     const Value *V = WorkList.pop_back_val();
300     for (const Use &UI : V->uses()) {
301       const auto *I = cast<const Instruction>(UI.getUser());
302       assert(V == UI.get());
303 
304       switch (I->getOpcode()) {
305       case Instruction::Load: {
306         US.updateRange(
307             getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
308         break;
309       }
310 
311       case Instruction::VAArg:
312         // "va-arg" from a pointer is safe.
313         break;
314       case Instruction::Store: {
315         if (V == I->getOperand(0)) {
316           // Stored the pointer - conservatively assume it may be unsafe.
317           US.updateRange(UnknownRange);
318           return false;
319         }
320         US.updateRange(getAccessRange(
321             UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
322         break;
323       }
324 
325       case Instruction::Ret:
326         // Information leak.
327         // FIXME: Process parameters correctly. This is a leak only if we return
328         // alloca.
329         US.updateRange(UnknownRange);
330         return false;
331 
332       case Instruction::Call:
333       case Instruction::Invoke: {
334         if (I->isLifetimeStartOrEnd())
335           break;
336 
337         if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
338           US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr));
339           break;
340         }
341 
342         const auto &CB = cast<CallBase>(*I);
343         if (!CB.isArgOperand(&UI)) {
344           US.updateRange(UnknownRange);
345           return false;
346         }
347 
348         unsigned ArgNo = CB.getArgOperandNo(&UI);
349         if (CB.isByValArgument(ArgNo)) {
350           US.updateRange(getAccessRange(
351               UI, Ptr, DL.getTypeStoreSize(CB.getParamByValType(ArgNo))));
352           break;
353         }
354 
355         // FIXME: consult devirt?
356         // Do not follow aliases, otherwise we could inadvertently follow
357         // dso_preemptable aliases or aliases with interposable linkage.
358         const GlobalValue *Callee =
359             dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
360         if (!Callee) {
361           US.updateRange(UnknownRange);
362           return false;
363         }
364 
365         assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee));
366         US.Calls.emplace_back(Callee, ArgNo, offsetFrom(UI, Ptr));
367         break;
368       }
369 
370       default:
371         if (Visited.insert(I).second)
372           WorkList.push_back(cast<const Instruction>(I));
373       }
374     }
375   }
376 
377   return true;
378 }
379 
380 FunctionInfo<GlobalValue> StackSafetyLocalAnalysis::run() {
381   FunctionInfo<GlobalValue> Info;
382   assert(!F.isDeclaration() &&
383          "Can't run StackSafety on a function declaration");
384 
385   LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n");
386 
387   for (auto &I : instructions(F)) {
388     if (auto *AI = dyn_cast<AllocaInst>(&I)) {
389       auto &UI = Info.Allocas.emplace(AI, PointerSize).first->second;
390       analyzeAllUses(AI, UI);
391     }
392   }
393 
394   for (Argument &A : make_range(F.arg_begin(), F.arg_end())) {
395     // Non pointers and bypass arguments are not going to be used in any global
396     // processing.
397     if (A.getType()->isPointerTy() && !A.hasByValAttr()) {
398       auto &UI = Info.Params.emplace(A.getArgNo(), PointerSize).first->second;
399       analyzeAllUses(&A, UI);
400     }
401   }
402 
403   LLVM_DEBUG(Info.print(dbgs(), F.getName(), &F));
404   LLVM_DEBUG(dbgs() << "[StackSafety] done\n");
405   return Info;
406 }
407 
/// Inter-procedural fixed-point analysis: propagates callee parameter access
/// ranges into their callers until no range changes. It is instantiated both
/// for in-module functions (CalleeTy = GlobalValue) and for summaries
/// (CalleeTy = FunctionSummary).
template <typename CalleeTy> class StackSafetyDataFlowAnalysis {
  using FunctionMap = std::map<const CalleeTy *, FunctionInfo<CalleeTy>>;

  FunctionMap Functions;
  // Full-set range returned for unknown callees or parameters.
  const ConstantRange UnknownRange;

  // Callee-to-Caller multimap.
  DenseMap<const CalleeTy *, SmallVector<const CalleeTy *, 4>> Callers;
  // Functions whose parameter summaries must be recomputed.
  SetVector<const CalleeTy *> WorkList;

  // Folds callee access ranges into US; returns true if US.Range changed.
  bool updateOneUse(UseInfo<CalleeTy> &US, bool UpdateToFullSet);
  // Recomputes FS's parameters and queues its callers if anything changed.
  void updateOneNode(const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS);
  void updateOneNode(const CalleeTy *Callee) {
    updateOneNode(Callee, Functions.find(Callee)->second);
  }
  void updateAllNodes() {
    for (auto &F : Functions)
      updateOneNode(F.first, F.second);
  }
  void runDataFlow();
#ifndef NDEBUG
  void verifyFixedPoint();
#endif

public:
  StackSafetyDataFlowAnalysis(uint32_t PointerBitWidth, FunctionMap Functions)
      : Functions(std::move(Functions)),
        UnknownRange(ConstantRange::getFull(PointerBitWidth)) {}

  // Runs the data flow to a fixed point and returns the updated functions.
  const FunctionMap &run();

  // Range accessed through parameter ParamNo of Callee when passed an address
  // at the given Offsets.
  ConstantRange getArgumentAccessRange(const CalleeTy *Callee, unsigned ParamNo,
                                       const ConstantRange &Offsets) const;
};
442 
443 template <typename CalleeTy>
444 ConstantRange StackSafetyDataFlowAnalysis<CalleeTy>::getArgumentAccessRange(
445     const CalleeTy *Callee, unsigned ParamNo,
446     const ConstantRange &Offsets) const {
447   auto FnIt = Functions.find(Callee);
448   // Unknown callee (outside of LTO domain or an indirect call).
449   if (FnIt == Functions.end())
450     return UnknownRange;
451   auto &FS = FnIt->second;
452   auto ParamIt = FS.Params.find(ParamNo);
453   if (ParamIt == FS.Params.end())
454     return UnknownRange;
455   auto &Access = ParamIt->second.Range;
456   if (Access.isEmptySet())
457     return Access;
458   if (Access.isFullSet())
459     return UnknownRange;
460   return addOverflowNever(Access, Offsets);
461 }
462 
463 template <typename CalleeTy>
464 bool StackSafetyDataFlowAnalysis<CalleeTy>::updateOneUse(UseInfo<CalleeTy> &US,
465                                                          bool UpdateToFullSet) {
466   bool Changed = false;
467   for (auto &CS : US.Calls) {
468     assert(!CS.Offset.isEmptySet() &&
469            "Param range can't be empty-set, invalid offset range");
470 
471     ConstantRange CalleeRange =
472         getArgumentAccessRange(CS.Callee, CS.ParamNo, CS.Offset);
473     if (!US.Range.contains(CalleeRange)) {
474       Changed = true;
475       if (UpdateToFullSet)
476         US.Range = UnknownRange;
477       else
478         US.Range = US.Range.unionWith(CalleeRange);
479     }
480   }
481   return Changed;
482 }
483 
484 template <typename CalleeTy>
485 void StackSafetyDataFlowAnalysis<CalleeTy>::updateOneNode(
486     const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS) {
487   bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations;
488   bool Changed = false;
489   for (auto &KV : FS.Params)
490     Changed |= updateOneUse(KV.second, UpdateToFullSet);
491 
492   if (Changed) {
493     LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount
494                       << (UpdateToFullSet ? ", full-set" : "") << "] " << &FS
495                       << "\n");
496     // Callers of this function may need updating.
497     for (auto &CallerID : Callers[Callee])
498       WorkList.insert(CallerID);
499 
500     ++FS.UpdateCount;
501   }
502 }
503 
504 template <typename CalleeTy>
505 void StackSafetyDataFlowAnalysis<CalleeTy>::runDataFlow() {
506   SmallVector<const CalleeTy *, 16> Callees;
507   for (auto &F : Functions) {
508     Callees.clear();
509     auto &FS = F.second;
510     for (auto &KV : FS.Params)
511       for (auto &CS : KV.second.Calls)
512         Callees.push_back(CS.Callee);
513 
514     llvm::sort(Callees);
515     Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end());
516 
517     for (auto &Callee : Callees)
518       Callers[Callee].push_back(F.first);
519   }
520 
521   updateAllNodes();
522 
523   while (!WorkList.empty()) {
524     const CalleeTy *Callee = WorkList.back();
525     WorkList.pop_back();
526     updateOneNode(Callee);
527   }
528 }
529 
#ifndef NDEBUG
/// Debug-only check: one more pass over all nodes must not queue any work,
/// i.e. the data flow has reached a fixed point.
template <typename CalleeTy>
void StackSafetyDataFlowAnalysis<CalleeTy>::verifyFixedPoint() {
  WorkList.clear();
  updateAllNodes();
  assert(WorkList.empty());
}
#endif

/// Runs the data flow to a fixed point (verified in debug builds) and returns
/// the updated function map.
template <typename CalleeTy>
const typename StackSafetyDataFlowAnalysis<CalleeTy>::FunctionMap &
StackSafetyDataFlowAnalysis<CalleeTy>::run() {
  runDataFlow();
  LLVM_DEBUG(verifyFixedPoint());
  return Functions;
}
546 
547 FunctionSummary *resolveCallee(GlobalValueSummary *S) {
548   while (S) {
549     if (!S->isLive() || !S->isDSOLocal())
550       return nullptr;
551     if (FunctionSummary *FS = dyn_cast<FunctionSummary>(S))
552       return FS;
553     AliasSummary *AS = dyn_cast<AliasSummary>(S);
554     if (!AS)
555       return nullptr;
556     S = AS->getBaseObject();
557     if (S == AS)
558       return nullptr;
559   }
560   return nullptr;
561 }
562 
563 const Function *findCalleeInModule(const GlobalValue *GV) {
564   while (GV) {
565     if (GV->isDeclaration() || GV->isInterposable() || !GV->isDSOLocal())
566       return nullptr;
567     if (const Function *F = dyn_cast<Function>(GV))
568       return F;
569     const GlobalAlias *A = dyn_cast<GlobalAlias>(GV);
570     if (!A)
571       return nullptr;
572     GV = A->getBaseObject();
573     if (GV == A)
574       return nullptr;
575   }
576   return nullptr;
577 }
578 
579 GlobalValueSummary *getGlobalValueSummary(const ModuleSummaryIndex *Index,
580                                           uint64_t ValueGUID) {
581   auto VI = Index->getValueInfo(ValueGUID);
582   if (!VI || VI.getSummaryList().empty())
583     return nullptr;
584   assert(VI.getSummaryList().size() == 1);
585   auto &Summary = VI.getSummaryList()[0];
586   return Summary.get();
587 }
588 
589 const ConstantRange *findParamAccess(const FunctionSummary &FS,
590                                      uint32_t ParamNo) {
591   assert(FS.isLive());
592   assert(FS.isDSOLocal());
593   for (auto &PS : FS.paramAccesses())
594     if (ParamNo == PS.ParamNo)
595       return &PS.Use;
596   return nullptr;
597 }
598 
599 void resolveAllCalls(UseInfo<GlobalValue> &Use,
600                      const ModuleSummaryIndex *Index) {
601   ConstantRange FullSet(Use.Range.getBitWidth(), true);
602   for (auto &C : Use.Calls) {
603     const Function *F = findCalleeInModule(C.Callee);
604     if (F) {
605       C.Callee = F;
606       continue;
607     }
608 
609     if (!Index)
610       return Use.updateRange(FullSet);
611     GlobalValueSummary *GVS = getGlobalValueSummary(Index, C.Callee->getGUID());
612 
613     FunctionSummary *FS = resolveCallee(GVS);
614     if (!FS)
615       return Use.updateRange(FullSet);
616     const ConstantRange *Found = findParamAccess(*FS, C.ParamNo);
617     if (!Found)
618       return Use.updateRange(FullSet);
619     ConstantRange Access = Found->sextOrTrunc(Use.Range.getBitWidth());
620     Use.updateRange(addOverflowNever(Access, C.Offset));
621     C.Callee = nullptr;
622   }
623 
624   Use.Calls.erase(std::remove_if(Use.Calls.begin(), Use.Calls.end(),
625                                  [](auto &T) { return !T.Callee; }),
626                   Use.Calls.end());
627 }
628 
629 GVToSSI createGlobalStackSafetyInfo(
630     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions,
631     const ModuleSummaryIndex *Index) {
632   GVToSSI SSI;
633   if (Functions.empty())
634     return SSI;
635 
636   // FIXME: Simplify printing and remove copying here.
637   auto Copy = Functions;
638 
639   for (auto &FnKV : Copy)
640     for (auto &KV : FnKV.second.Params)
641       resolveAllCalls(KV.second, Index);
642 
643   uint32_t PointerSize = Copy.begin()
644                              ->first->getParent()
645                              ->getDataLayout()
646                              .getMaxPointerSizeInBits();
647   StackSafetyDataFlowAnalysis<GlobalValue> SSDFA(PointerSize, std::move(Copy));
648 
649   for (auto &F : SSDFA.run()) {
650     auto FI = F.second;
651     auto &SrcF = Functions[F.first];
652     for (auto &KV : FI.Allocas) {
653       auto &A = KV.second;
654       resolveAllCalls(A, Index);
655       for (auto &C : A.Calls) {
656         A.updateRange(
657             SSDFA.getArgumentAccessRange(C.Callee, C.ParamNo, C.Offset));
658       }
659       // FIXME: This is needed only to preserve calls in print() results.
660       A.Calls = SrcF.Allocas.find(KV.first)->second.Calls;
661     }
662     for (auto &KV : FI.Params) {
663       auto &P = KV.second;
664       P.Calls = SrcF.Params.find(KV.first)->second.Calls;
665     }
666     SSI[F.first] = std::move(FI);
667   }
668 
669   return SSI;
670 }
671 
672 } // end anonymous namespace
673 
// NOTE(review): the special members are presumably defined out of line so
// that InfoTy (defined in this file) is complete where the owning smart
// pointer is destroyed -- confirm against the header.
StackSafetyInfo::StackSafetyInfo() = default;

/// Creates a lazily computed local result for F. GetSE is invoked only when
/// the result is first requested (see getInfo()).
StackSafetyInfo::StackSafetyInfo(Function *F,
                                 std::function<ScalarEvolution &()> GetSE)
    : F(F), GetSE(GetSE) {}

StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default;

StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default;

StackSafetyInfo::~StackSafetyInfo() = default;
685 
686 const StackSafetyInfo::InfoTy &StackSafetyInfo::getInfo() const {
687   if (!Info) {
688     StackSafetyLocalAnalysis SSLA(*F, GetSE());
689     Info.reset(new InfoTy{SSLA.run()});
690   }
691   return *Info;
692 }
693 
694 void StackSafetyInfo::print(raw_ostream &O) const {
695   getInfo().Info.print(O, F->getName(), dyn_cast<Function>(F));
696 }
697 
698 const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
699   if (!Info) {
700     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions;
701     for (auto &F : M->functions()) {
702       if (!F.isDeclaration()) {
703         auto FI = GetSSI(F).getInfo().Info;
704         Functions.emplace(&F, std::move(FI));
705       }
706     }
707     Info.reset(new InfoTy{
708         createGlobalStackSafetyInfo(std::move(Functions), Index), {}});
709     for (auto &FnKV : Info->Info) {
710       for (auto &KV : FnKV.second.Allocas) {
711         ++NumAllocaTotal;
712         const AllocaInst *AI = KV.first;
713         if (getStaticAllocaSizeRange(*AI).contains(KV.second.Range)) {
714           Info->SafeAllocas.insert(AI);
715           ++NumAllocaStackSafe;
716         }
717       }
718     }
719     if (StackSafetyPrint)
720       print(errs());
721   }
722   return *Info;
723 }
724 
725 // Converts a StackSafetyFunctionInfo to the relevant FunctionSummary
726 // constructor fields
727 std::vector<FunctionSummary::ParamAccess>
728 StackSafetyInfo::getParamAccesses() const {
729   assert(needsParamAccessSummary(*F->getParent()));
730 
731   std::vector<FunctionSummary::ParamAccess> ParamAccesses;
732   for (const auto &KV : getInfo().Info.Params) {
733     auto &PS = KV.second;
734     if (PS.Range.isFullSet())
735       continue;
736 
737     ParamAccesses.emplace_back(KV.first, PS.Range);
738     FunctionSummary::ParamAccess &Param = ParamAccesses.back();
739 
740     Param.Calls.reserve(PS.Calls.size());
741     for (auto &C : PS.Calls) {
742       if (C.Offset.isFullSet()) {
743         ParamAccesses.pop_back();
744         break;
745       }
746       Param.Calls.emplace_back(C.ParamNo, C.Callee->getGUID(), C.Offset);
747     }
748   }
749   return ParamAccesses;
750 }
751 
StackSafetyGlobalInfo::StackSafetyGlobalInfo() = default;

/// Creates a lazily computed module-wide result. GetSSI supplies the local
/// analysis of each function. Index may be null; when present it is consulted
/// for callees that are not defined in M. With -stack-safety-run the result is
/// computed immediately.
StackSafetyGlobalInfo::StackSafetyGlobalInfo(
    Module *M, std::function<const StackSafetyInfo &(Function &F)> GetSSI,
    const ModuleSummaryIndex *Index)
    : M(M), GetSSI(GetSSI), Index(Index) {
  if (StackSafetyRun)
    getInfo();
}

StackSafetyGlobalInfo::StackSafetyGlobalInfo(StackSafetyGlobalInfo &&) =
    default;

StackSafetyGlobalInfo &
StackSafetyGlobalInfo::operator=(StackSafetyGlobalInfo &&) = default;

StackSafetyGlobalInfo::~StackSafetyGlobalInfo() = default;
769 
770 bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
771   const auto &Info = getInfo();
772   return Info.SafeAllocas.count(&AI);
773 }
774 
775 void StackSafetyGlobalInfo::print(raw_ostream &O) const {
776   auto &SSI = getInfo().Info;
777   if (SSI.empty())
778     return;
779   const Module &M = *SSI.begin()->first->getParent();
780   for (auto &F : M.functions()) {
781     if (!F.isDeclaration()) {
782       SSI.find(&F)->second.print(O, F.getName(), &F);
783       O << "\n";
784     }
785   }
786 }
787 
/// Prints the module-wide result to dbgs().
LLVM_DUMP_METHOD void StackSafetyGlobalInfo::dump() const { print(dbgs()); }
789 
790 AnalysisKey StackSafetyAnalysis::Key;
791 
792 StackSafetyInfo StackSafetyAnalysis::run(Function &F,
793                                          FunctionAnalysisManager &AM) {
794   return StackSafetyInfo(&F, [&AM, &F]() -> ScalarEvolution & {
795     return AM.getResult<ScalarEvolutionAnalysis>(F);
796   });
797 }
798 
799 PreservedAnalyses StackSafetyPrinterPass::run(Function &F,
800                                               FunctionAnalysisManager &AM) {
801   OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n";
802   AM.getResult<StackSafetyAnalysis>(F).print(OS);
803   return PreservedAnalyses::all();
804 }
805 
806 char StackSafetyInfoWrapperPass::ID = 0;
807 
808 StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) {
809   initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry());
810 }
811 
812 void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
813   AU.addRequiredTransitive<ScalarEvolutionWrapperPass>();
814   AU.setPreservesAll();
815 }
816 
817 void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const {
818   SSI.print(O);
819 }
820 
821 bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) {
822   auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
823   SSI = {&F, [SE]() -> ScalarEvolution & { return *SE; }};
824   return false;
825 }
826 
827 AnalysisKey StackSafetyGlobalAnalysis::Key;
828 
829 StackSafetyGlobalInfo
830 StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
831   // FIXME: Lookup Module Summary.
832   FunctionAnalysisManager &FAM =
833       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
834   return {&M,
835           [&FAM](Function &F) -> const StackSafetyInfo & {
836             return FAM.getResult<StackSafetyAnalysis>(F);
837           },
838           nullptr};
839 }
840 
841 PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M,
842                                                     ModuleAnalysisManager &AM) {
843   OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n";
844   AM.getResult<StackSafetyGlobalAnalysis>(M).print(OS);
845   return PreservedAnalyses::all();
846 }
847 
848 char StackSafetyGlobalInfoWrapperPass::ID = 0;
849 
850 StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass()
851     : ModulePass(ID) {
852   initializeStackSafetyGlobalInfoWrapperPassPass(
853       *PassRegistry::getPassRegistry());
854 }
855 
856 StackSafetyGlobalInfoWrapperPass::~StackSafetyGlobalInfoWrapperPass() = default;
857 
858 void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O,
859                                              const Module *M) const {
860   SSGI.print(O);
861 }
862 
863 void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage(
864     AnalysisUsage &AU) const {
865   AU.setPreservesAll();
866   AU.addRequired<StackSafetyInfoWrapperPass>();
867 }
868 
869 bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) {
870   const ModuleSummaryIndex *ImportSummary = nullptr;
871   if (auto *IndexWrapperPass =
872           getAnalysisIfAvailable<ImmutableModuleSummaryIndexWrapperPass>())
873     ImportSummary = IndexWrapperPass->getIndex();
874 
875   SSGI = {&M,
876           [this](Function &F) -> const StackSafetyInfo & {
877             return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult();
878           },
879           ImportSummary};
880   return false;
881 }
882 
883 bool llvm::needsParamAccessSummary(const Module &M) {
884   if (StackSafetyRun)
885     return true;
886   for (auto &F : M.functions())
887     if (F.hasFnAttribute(Attribute::SanitizeMemTag))
888       return true;
889   return false;
890 }
891 
892 void llvm::generateParamAccessSummary(ModuleSummaryIndex &Index) {
893   const ConstantRange FullSet(FunctionSummary::ParamAccess::RangeWidth, true);
894   std::map<const FunctionSummary *, FunctionInfo<FunctionSummary>> Functions;
895 
896   // Convert the ModuleSummaryIndex to a FunctionMap
897   for (auto &GVS : Index) {
898     for (auto &GV : GVS.second.SummaryList) {
899       FunctionSummary *FS = dyn_cast<FunctionSummary>(GV.get());
900       if (!FS)
901         continue;
902       if (FS->isLive() && FS->isDSOLocal()) {
903         FunctionInfo<FunctionSummary> FI;
904         for (auto &PS : FS->paramAccesses()) {
905           auto &US =
906               FI.Params
907                   .emplace(PS.ParamNo, FunctionSummary::ParamAccess::RangeWidth)
908                   .first->second;
909           US.Range = PS.Use;
910           for (auto &Call : PS.Calls) {
911             assert(!Call.Offsets.isFullSet());
912             FunctionSummary *S = resolveCallee(
913                 Index.findSummaryInModule(Call.Callee, FS->modulePath()));
914             if (!S) {
915               US.Range = FullSet;
916               US.Calls.clear();
917               break;
918             }
919             US.Calls.emplace_back(S, Call.ParamNo, Call.Offsets);
920           }
921         }
922         Functions.emplace(FS, std::move(FI));
923       }
924       // Reset data for all summaries. Alive and DSO local will be set back from
925       // of data flow results below. Anything else will not be accessed
926       // by ThinLTO backend, so we can save on bitcode size.
927       FS->setParamAccesses({});
928     }
929   }
930   StackSafetyDataFlowAnalysis<FunctionSummary> SSDFA(
931       FunctionSummary::ParamAccess::RangeWidth, std::move(Functions));
932   for (auto &KV : SSDFA.run()) {
933     std::vector<FunctionSummary::ParamAccess> NewParams;
934     NewParams.reserve(KV.second.Params.size());
935     for (auto &Param : KV.second.Params) {
936       NewParams.emplace_back();
937       FunctionSummary::ParamAccess &New = NewParams.back();
938       New.ParamNo = Param.first;
939       New.Use = Param.second.Range; // Only range is needed.
940     }
941     const_cast<FunctionSummary *>(KV.first)->setParamAccesses(
942         std::move(NewParams));
943   }
944 }
945 
// Legacy pass manager registration of the local (per-function) analysis.
// NOTE(review): the trailing flags are presumably CFGOnly=false and
// IsAnalysis=true -- see the INITIALIZE_PASS_* macros in PassSupport.h.
static const char LocalPassArg[] = "stack-safety-local";
static const char LocalPassName[] = "Stack Safety Local Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                      false, true)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                    false, true)

// Legacy pass manager registration of the module-wide analysis, registered
// under the DEBUG_TYPE name ("stack-safety").
static const char GlobalPassName[] = "Stack Safety Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                      GlobalPassName, false, true)
INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ImmutableModuleSummaryIndexWrapperPass)
INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                    GlobalPassName, false, true)
961