1 //===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "llvm/Analysis/StackSafetyAnalysis.h"
12 #include "llvm/ADT/APInt.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/ModuleSummaryAnalysis.h"
17 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
18 #include "llvm/Analysis/StackLifetime.h"
19 #include "llvm/IR/ConstantRange.h"
20 #include "llvm/IR/DerivedTypes.h"
21 #include "llvm/IR/GlobalValue.h"
22 #include "llvm/IR/InstIterator.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/InitializePasses.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <memory>
32 
33 using namespace llvm;
34 
#define DEBUG_TYPE "stack-safety"

// Per-alloca outcome counters, updated when the module-level result is built.
STATISTIC(NumAllocaStackSafe, "Number of safe allocas");
STATISTIC(NumAllocaTotal, "Number of total allocas");

// Callee-lookup and parameter-access counters. NumModuleCalleeLookup* are
// updated by resolveAllCalls; the NumCombined* counters are not updated in
// this part of the file.
STATISTIC(NumCombinedCalleeLookupTotal,
          "Number of total callee lookups on combined index.");
STATISTIC(NumCombinedCalleeLookupFailed,
          "Number of failed callee lookups on combined index.");
STATISTIC(NumModuleCalleeLookupTotal,
          "Number of total callee lookups on module index.");
STATISTIC(NumModuleCalleeLookupFailed,
          "Number of failed callee lookups on module index.");
STATISTIC(NumCombinedParamAccessesBefore,
          "Number of total param accesses before generateParamAccessSummary.");
STATISTIC(NumCombinedParamAccessesAfter,
          "Number of total param accesses after generateParamAccessSummary.");
STATISTIC(NumCombinedDataFlowNodes,
          "Number of total nodes in combined index for dataflow processing.");

// Number of summary updates a function may receive in the interprocedural
// data flow before any further growth jumps straight to the full range.
static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations",
                                             cl::init(20), cl::Hidden);

// When set, the module-level result is printed to errs() once computed.
static cl::opt<bool> StackSafetyPrint("stack-safety-print", cl::init(false),
                                      cl::Hidden);

// When set, the module-level result is computed eagerly in the
// StackSafetyGlobalInfo constructor instead of on first query.
static cl::opt<bool> StackSafetyRun("stack-safety-run", cl::init(false),
                                    cl::Hidden);
63 
64 namespace {
65 
66 // Check if we should bailout for such ranges.
67 bool isUnsafe(const ConstantRange &R) {
68   return R.isEmptySet() || R.isFullSet() || R.isUpperSignWrapped();
69 }
70 
71 ConstantRange addOverflowNever(const ConstantRange &L, const ConstantRange &R) {
72   assert(!L.isSignWrappedSet());
73   assert(!R.isSignWrappedSet());
74   if (L.signedAddMayOverflow(R) !=
75       ConstantRange::OverflowResult::NeverOverflows)
76     return ConstantRange::getFull(L.getBitWidth());
77   ConstantRange Result = L.add(R);
78   assert(!Result.isSignWrappedSet());
79   return Result;
80 }
81 
82 ConstantRange unionNoWrap(const ConstantRange &L, const ConstantRange &R) {
83   assert(!L.isSignWrappedSet());
84   assert(!R.isSignWrappedSet());
85   auto Result = L.unionWith(R);
86   // Two non-wrapped sets can produce wrapped.
87   if (Result.isSignWrappedSet())
88     Result = ConstantRange::getFull(Result.getBitWidth());
89   return Result;
90 }
91 
/// Describes the use of an address as a function call argument: which callee
/// receives it and at which argument position.
template <typename CalleeTy> struct CallInfo {
  /// Function (or alias) being called.
  const CalleeTy *Callee = nullptr;
  /// Zero-based index of the argument that carries the address.
  size_t ParamNo = 0;

  CallInfo(const CalleeTy *Callee, size_t ParamNo)
      : Callee(Callee), ParamNo(ParamNo) {}

  /// Lexicographic ordering on (Callee, ParamNo), usable as a map comparator.
  struct Less {
    bool operator()(const CallInfo &L, const CallInfo &R) const {
      if (L.Callee != R.Callee)
        return L.Callee < R.Callee;
      return L.ParamNo < R.ParamNo;
    }
  };
};
108 
109 /// Describe uses of address (alloca or parameter) inside of the function.
110 template <typename CalleeTy> struct UseInfo {
111   // Access range if the address (alloca or parameters).
112   // It is allowed to be empty-set when there are no known accesses.
113   ConstantRange Range;
114 
115   // List of calls which pass address as an argument.
116   // Value is offset range of address from base address (alloca or calling
117   // function argument). Range should never set to empty-set, that is an invalid
118   // access range that can cause empty-set to be propagated with
119   // ConstantRange::add
120   std::map<CallInfo<CalleeTy>, ConstantRange, typename CallInfo<CalleeTy>::Less>
121       Calls;
122 
123   UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
124 
125   void updateRange(const ConstantRange &R) { Range = unionNoWrap(Range, R); }
126 };
127 
128 template <typename CalleeTy>
129 raw_ostream &operator<<(raw_ostream &OS, const UseInfo<CalleeTy> &U) {
130   OS << U.Range;
131   for (auto &Call : U.Calls)
132     OS << ", "
133        << "@" << Call.first.Callee->getName() << "(arg" << Call.first.ParamNo
134        << ", " << Call.second << ")";
135   return OS;
136 }
137 
138 /// Calculate the allocation size of a given alloca. Returns empty range
139 // in case of confution.
140 ConstantRange getStaticAllocaSizeRange(const AllocaInst &AI) {
141   const DataLayout &DL = AI.getModule()->getDataLayout();
142   TypeSize TS = DL.getTypeAllocSize(AI.getAllocatedType());
143   unsigned PointerSize = DL.getMaxPointerSizeInBits();
144   // Fallback to empty range for alloca size.
145   ConstantRange R = ConstantRange::getEmpty(PointerSize);
146   if (TS.isScalable())
147     return R;
148   APInt APSize(PointerSize, TS.getFixedSize(), true);
149   if (APSize.isNonPositive())
150     return R;
151   if (AI.isArrayAllocation()) {
152     const auto *C = dyn_cast<ConstantInt>(AI.getArraySize());
153     if (!C)
154       return R;
155     bool Overflow = false;
156     APInt Mul = C->getValue();
157     if (Mul.isNonPositive())
158       return R;
159     Mul = Mul.sextOrTrunc(PointerSize);
160     APSize = APSize.smul_ov(Mul, Overflow);
161     if (Overflow)
162       return R;
163   }
164   R = ConstantRange(APInt::getNullValue(PointerSize), APSize);
165   assert(!isUnsafe(R));
166   return R;
167 }
168 
169 template <typename CalleeTy> struct FunctionInfo {
170   std::map<const AllocaInst *, UseInfo<CalleeTy>> Allocas;
171   std::map<uint32_t, UseInfo<CalleeTy>> Params;
172   // TODO: describe return value as depending on one or more of its arguments.
173 
174   // StackSafetyDataFlowAnalysis counter stored here for faster access.
175   int UpdateCount = 0;
176 
177   void print(raw_ostream &O, StringRef Name, const Function *F) const {
178     // TODO: Consider different printout format after
179     // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then.
180     O << "  @" << Name << ((F && F->isDSOLocal()) ? "" : " dso_preemptable")
181       << ((F && F->isInterposable()) ? " interposable" : "") << "\n";
182 
183     O << "    args uses:\n";
184     for (auto &KV : Params) {
185       O << "      ";
186       if (F)
187         O << F->getArg(KV.first)->getName();
188       else
189         O << formatv("arg{0}", KV.first);
190       O << "[]: " << KV.second << "\n";
191     }
192 
193     O << "    allocas uses:\n";
194     if (F) {
195       for (auto &I : instructions(F)) {
196         if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
197           auto &AS = Allocas.find(AI)->second;
198           O << "      " << AI->getName() << "["
199             << getStaticAllocaSizeRange(*AI).getUpper() << "]: " << AS << "\n";
200         }
201       }
202     } else {
203       assert(Allocas.empty());
204     }
205     O << "\n";
206   }
207 };
208 
// Per-function stack-safety results, keyed by the function's GlobalValue.
using GVToSSI = std::map<const GlobalValue *, FunctionInfo<GlobalValue>>;
210 
211 } // namespace
212 
// Lazily built storage behind StackSafetyInfo: the local (single-function)
// summary produced by StackSafetyLocalAnalysis::run.
struct StackSafetyInfo::InfoTy {
  FunctionInfo<GlobalValue> Info;
};

// Lazily built storage behind StackSafetyGlobalInfo: per-function results
// after the interprocedural pass, plus the allocas whose accessed range is
// contained in their allocation size.
struct StackSafetyGlobalInfo::InfoTy {
  GVToSSI Info;
  SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
};
221 
222 namespace {
223 
/// Intraprocedural part of the analysis: summarizes how every alloca and every
/// non-byval pointer argument of F is used (accessed byte ranges and the calls
/// the address is passed to).
class StackSafetyLocalAnalysis {
  Function &F;
  const DataLayout &DL;
  ScalarEvolution &SE;
  // Bit width of every range produced here; set from
  // DL.getPointerSizeInBits() in the constructor.
  unsigned PointerSize = 0;

  // Full-set range returned whenever an access cannot be bounded.
  // NOTE: initialized from PointerSize, so it must stay declared after it.
  const ConstantRange UnknownRange;

  // Signed range of byte offsets of Addr relative to Base.
  ConstantRange offsetFrom(Value *Addr, Value *Base);
  // Byte range, relative to Base, touched by an access of SizeRange bytes at
  // Addr.
  ConstantRange getAccessRange(Value *Addr, Value *Base,
                               const ConstantRange &SizeRange);
  ConstantRange getAccessRange(Value *Addr, Value *Base, TypeSize Size);
  // Byte range, relative to Base, touched by memory intrinsic MI through U.
  ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U,
                                           Value *Base);

  // Walks all uses of Ptr and fills AS; returns false if it had to give up.
  bool analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS,
                      const StackLifetime &SL);

public:
  StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
      : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
        PointerSize(DL.getPointerSizeInBits()),
        UnknownRange(PointerSize, true) {}

  // Runs the analysis on F and returns its per-alloca/per-parameter summary.
  FunctionInfo<GlobalValue> run();
};
251 
252 ConstantRange StackSafetyLocalAnalysis::offsetFrom(Value *Addr, Value *Base) {
253   if (!SE.isSCEVable(Addr->getType()) || !SE.isSCEVable(Base->getType()))
254     return UnknownRange;
255 
256   auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
257   const SCEV *AddrExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Addr), PtrTy);
258   const SCEV *BaseExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Base), PtrTy);
259   const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
260 
261   ConstantRange Offset = SE.getSignedRange(Diff);
262   if (isUnsafe(Offset))
263     return UnknownRange;
264   return Offset.sextOrTrunc(PointerSize);
265 }
266 
267 ConstantRange
268 StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
269                                          const ConstantRange &SizeRange) {
270   // Zero-size loads and stores do not access memory.
271   if (SizeRange.isEmptySet())
272     return ConstantRange::getEmpty(PointerSize);
273   assert(!isUnsafe(SizeRange));
274 
275   ConstantRange Offsets = offsetFrom(Addr, Base);
276   if (isUnsafe(Offsets))
277     return UnknownRange;
278 
279   Offsets = addOverflowNever(Offsets, SizeRange);
280   if (isUnsafe(Offsets))
281     return UnknownRange;
282   return Offsets;
283 }
284 
285 ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
286                                                        TypeSize Size) {
287   if (Size.isScalable())
288     return UnknownRange;
289   APInt APSize(PointerSize, Size.getFixedSize(), true);
290   if (APSize.isNegative())
291     return UnknownRange;
292   return getAccessRange(
293       Addr, Base, ConstantRange(APInt::getNullValue(PointerSize), APSize));
294 }
295 
296 ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
297     const MemIntrinsic *MI, const Use &U, Value *Base) {
298   if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
299     if (MTI->getRawSource() != U && MTI->getRawDest() != U)
300       return ConstantRange::getEmpty(PointerSize);
301   } else {
302     if (MI->getRawDest() != U)
303       return ConstantRange::getEmpty(PointerSize);
304   }
305 
306   auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
307   if (!SE.isSCEVable(MI->getLength()->getType()))
308     return UnknownRange;
309 
310   const SCEV *Expr =
311       SE.getTruncateOrZeroExtend(SE.getSCEV(MI->getLength()), CalculationTy);
312   ConstantRange Sizes = SE.getSignedRange(Expr);
313   if (Sizes.getUpper().isNegative() || isUnsafe(Sizes))
314     return UnknownRange;
315   Sizes = Sizes.sextOrTrunc(PointerSize);
316   ConstantRange SizeRange(APInt::getNullValue(PointerSize),
317                           Sizes.getUpper() - 1);
318   return getAccessRange(U, Base, SizeRange);
319 }
320 
321 /// The function analyzes all local uses of Ptr (alloca or argument) and
322 /// calculates local access range and all function calls where it was used.
323 bool StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
324                                               UseInfo<GlobalValue> &US,
325                                               const StackLifetime &SL) {
326   SmallPtrSet<const Value *, 16> Visited;
327   SmallVector<const Value *, 8> WorkList;
328   WorkList.push_back(Ptr);
329   const AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);
330 
331   // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
332   while (!WorkList.empty()) {
333     const Value *V = WorkList.pop_back_val();
334     for (const Use &UI : V->uses()) {
335       const auto *I = cast<Instruction>(UI.getUser());
336       if (!SL.isReachable(I))
337         continue;
338 
339       assert(V == UI.get());
340 
341       switch (I->getOpcode()) {
342       case Instruction::Load: {
343         if (AI && !SL.isAliveAfter(AI, I)) {
344           US.updateRange(UnknownRange);
345           return false;
346         }
347         US.updateRange(
348             getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
349         break;
350       }
351 
352       case Instruction::VAArg:
353         // "va-arg" from a pointer is safe.
354         break;
355       case Instruction::Store: {
356         if (V == I->getOperand(0)) {
357           // Stored the pointer - conservatively assume it may be unsafe.
358           US.updateRange(UnknownRange);
359           return false;
360         }
361         if (AI && !SL.isAliveAfter(AI, I)) {
362           US.updateRange(UnknownRange);
363           return false;
364         }
365         US.updateRange(getAccessRange(
366             UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
367         break;
368       }
369 
370       case Instruction::Ret:
371         // Information leak.
372         // FIXME: Process parameters correctly. This is a leak only if we return
373         // alloca.
374         US.updateRange(UnknownRange);
375         return false;
376 
377       case Instruction::Call:
378       case Instruction::Invoke: {
379         if (I->isLifetimeStartOrEnd())
380           break;
381 
382         if (AI && !SL.isAliveAfter(AI, I)) {
383           US.updateRange(UnknownRange);
384           return false;
385         }
386 
387         if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
388           US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr));
389           break;
390         }
391 
392         const auto &CB = cast<CallBase>(*I);
393         if (!CB.isArgOperand(&UI)) {
394           US.updateRange(UnknownRange);
395           return false;
396         }
397 
398         unsigned ArgNo = CB.getArgOperandNo(&UI);
399         if (CB.isByValArgument(ArgNo)) {
400           US.updateRange(getAccessRange(
401               UI, Ptr, DL.getTypeStoreSize(CB.getParamByValType(ArgNo))));
402           break;
403         }
404 
405         // FIXME: consult devirt?
406         // Do not follow aliases, otherwise we could inadvertently follow
407         // dso_preemptable aliases or aliases with interposable linkage.
408         const GlobalValue *Callee =
409             dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
410         if (!Callee) {
411           US.updateRange(UnknownRange);
412           return false;
413         }
414 
415         assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee));
416         ConstantRange Offsets = offsetFrom(UI, Ptr);
417         auto Insert =
418             US.Calls.emplace(CallInfo<GlobalValue>(Callee, ArgNo), Offsets);
419         if (!Insert.second)
420           Insert.first->second = Insert.first->second.unionWith(Offsets);
421         break;
422       }
423 
424       default:
425         if (Visited.insert(I).second)
426           WorkList.push_back(cast<const Instruction>(I));
427       }
428     }
429   }
430 
431   return true;
432 }
433 
434 FunctionInfo<GlobalValue> StackSafetyLocalAnalysis::run() {
435   FunctionInfo<GlobalValue> Info;
436   assert(!F.isDeclaration() &&
437          "Can't run StackSafety on a function declaration");
438 
439   LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n");
440 
441   SmallVector<AllocaInst *, 64> Allocas;
442   for (auto &I : instructions(F))
443     if (auto *AI = dyn_cast<AllocaInst>(&I))
444       Allocas.push_back(AI);
445   StackLifetime SL(F, Allocas, StackLifetime::LivenessType::Must);
446   SL.run();
447 
448   for (auto *AI : Allocas) {
449     auto &UI = Info.Allocas.emplace(AI, PointerSize).first->second;
450     analyzeAllUses(AI, UI, SL);
451   }
452 
453   for (Argument &A : make_range(F.arg_begin(), F.arg_end())) {
454     // Non pointers and bypass arguments are not going to be used in any global
455     // processing.
456     if (A.getType()->isPointerTy() && !A.hasByValAttr()) {
457       auto &UI = Info.Params.emplace(A.getArgNo(), PointerSize).first->second;
458       analyzeAllUses(&A, UI, SL);
459     }
460   }
461 
462   LLVM_DEBUG(Info.print(dbgs(), F.getName(), &F));
463   LLVM_DEBUG(dbgs() << "[StackSafety] done\n");
464   return Info;
465 }
466 
/// Interprocedural propagation of parameter access ranges. Each function's
/// Params summaries are widened by the ranges their callees access, driven by
/// a worklist, until no summary changes.
template <typename CalleeTy> class StackSafetyDataFlowAnalysis {
  using FunctionMap = std::map<const CalleeTy *, FunctionInfo<CalleeTy>>;

  // Per-function summaries, updated in place.
  FunctionMap Functions;
  // Full-set range of the analysis bit width.
  const ConstantRange UnknownRange;

  // Callee-to-Caller multimap.
  DenseMap<const CalleeTy *, SmallVector<const CalleeTy *, 4>> Callers;
  // Functions whose summaries must be re-evaluated.
  SetVector<const CalleeTy *> WorkList;

  bool updateOneUse(UseInfo<CalleeTy> &US, bool UpdateToFullSet);
  void updateOneNode(const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS);
  // Overload that looks up the summary of Callee (which must exist).
  void updateOneNode(const CalleeTy *Callee) {
    updateOneNode(Callee, Functions.find(Callee)->second);
  }
  // Re-evaluates every function once.
  void updateAllNodes() {
    for (auto &F : Functions)
      updateOneNode(F.first, F.second);
  }
  void runDataFlow();
#ifndef NDEBUG
  void verifyFixedPoint();
#endif

public:
  StackSafetyDataFlowAnalysis(uint32_t PointerBitWidth, FunctionMap Functions)
      : Functions(std::move(Functions)),
        UnknownRange(ConstantRange::getFull(PointerBitWidth)) {}

  // Runs the data flow and returns the updated summaries.
  const FunctionMap &run();

  // Range accessed through parameter ParamNo of Callee when it is passed an
  // address at Offsets from the caller's base.
  ConstantRange getArgumentAccessRange(const CalleeTy *Callee, unsigned ParamNo,
                                       const ConstantRange &Offsets) const;
};
501 
502 template <typename CalleeTy>
503 ConstantRange StackSafetyDataFlowAnalysis<CalleeTy>::getArgumentAccessRange(
504     const CalleeTy *Callee, unsigned ParamNo,
505     const ConstantRange &Offsets) const {
506   auto FnIt = Functions.find(Callee);
507   // Unknown callee (outside of LTO domain or an indirect call).
508   if (FnIt == Functions.end())
509     return UnknownRange;
510   auto &FS = FnIt->second;
511   auto ParamIt = FS.Params.find(ParamNo);
512   if (ParamIt == FS.Params.end())
513     return UnknownRange;
514   auto &Access = ParamIt->second.Range;
515   if (Access.isEmptySet())
516     return Access;
517   if (Access.isFullSet())
518     return UnknownRange;
519   return addOverflowNever(Access, Offsets);
520 }
521 
522 template <typename CalleeTy>
523 bool StackSafetyDataFlowAnalysis<CalleeTy>::updateOneUse(UseInfo<CalleeTy> &US,
524                                                          bool UpdateToFullSet) {
525   bool Changed = false;
526   for (auto &KV : US.Calls) {
527     assert(!KV.second.isEmptySet() &&
528            "Param range can't be empty-set, invalid offset range");
529 
530     ConstantRange CalleeRange =
531         getArgumentAccessRange(KV.first.Callee, KV.first.ParamNo, KV.second);
532     if (!US.Range.contains(CalleeRange)) {
533       Changed = true;
534       if (UpdateToFullSet)
535         US.Range = UnknownRange;
536       else
537         US.updateRange(CalleeRange);
538     }
539   }
540   return Changed;
541 }
542 
543 template <typename CalleeTy>
544 void StackSafetyDataFlowAnalysis<CalleeTy>::updateOneNode(
545     const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS) {
546   bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations;
547   bool Changed = false;
548   for (auto &KV : FS.Params)
549     Changed |= updateOneUse(KV.second, UpdateToFullSet);
550 
551   if (Changed) {
552     LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount
553                       << (UpdateToFullSet ? ", full-set" : "") << "] " << &FS
554                       << "\n");
555     // Callers of this function may need updating.
556     for (auto &CallerID : Callers[Callee])
557       WorkList.insert(CallerID);
558 
559     ++FS.UpdateCount;
560   }
561 }
562 
563 template <typename CalleeTy>
564 void StackSafetyDataFlowAnalysis<CalleeTy>::runDataFlow() {
565   SmallVector<const CalleeTy *, 16> Callees;
566   for (auto &F : Functions) {
567     Callees.clear();
568     auto &FS = F.second;
569     for (auto &KV : FS.Params)
570       for (auto &CS : KV.second.Calls)
571         Callees.push_back(CS.first.Callee);
572 
573     llvm::sort(Callees);
574     Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end());
575 
576     for (auto &Callee : Callees)
577       Callers[Callee].push_back(F.first);
578   }
579 
580   updateAllNodes();
581 
582   while (!WorkList.empty()) {
583     const CalleeTy *Callee = WorkList.back();
584     WorkList.pop_back();
585     updateOneNode(Callee);
586   }
587 }
588 
589 #ifndef NDEBUG
590 template <typename CalleeTy>
591 void StackSafetyDataFlowAnalysis<CalleeTy>::verifyFixedPoint() {
592   WorkList.clear();
593   updateAllNodes();
594   assert(WorkList.empty());
595 }
596 #endif
597 
598 template <typename CalleeTy>
599 const typename StackSafetyDataFlowAnalysis<CalleeTy>::FunctionMap &
600 StackSafetyDataFlowAnalysis<CalleeTy>::run() {
601   runDataFlow();
602   LLVM_DEBUG(verifyFixedPoint());
603   return Functions;
604 }
605 
606 FunctionSummary *resolveCallee(GlobalValueSummary *S) {
607   while (S) {
608     if (!S->isLive() || !S->isDSOLocal())
609       return nullptr;
610     if (FunctionSummary *FS = dyn_cast<FunctionSummary>(S))
611       return FS;
612     AliasSummary *AS = dyn_cast<AliasSummary>(S);
613     if (!AS || !AS->hasAliasee())
614       return nullptr;
615     S = AS->getBaseObject();
616     if (S == AS)
617       return nullptr;
618   }
619   return nullptr;
620 }
621 
622 const Function *findCalleeInModule(const GlobalValue *GV) {
623   while (GV) {
624     if (GV->isDeclaration() || GV->isInterposable() || !GV->isDSOLocal())
625       return nullptr;
626     if (const Function *F = dyn_cast<Function>(GV))
627       return F;
628     const GlobalAlias *A = dyn_cast<GlobalAlias>(GV);
629     if (!A)
630       return nullptr;
631     GV = A->getBaseObject();
632     if (GV == A)
633       return nullptr;
634   }
635   return nullptr;
636 }
637 
638 GlobalValueSummary *getGlobalValueSummary(const ModuleSummaryIndex *Index,
639                                           uint64_t ValueGUID) {
640   auto VI = Index->getValueInfo(ValueGUID);
641   if (!VI || VI.getSummaryList().empty())
642     return nullptr;
643   auto &Summary = VI.getSummaryList()[0];
644   return Summary.get();
645 }
646 
647 const ConstantRange *findParamAccess(const FunctionSummary &FS,
648                                      uint32_t ParamNo) {
649   assert(FS.isLive());
650   assert(FS.isDSOLocal());
651   for (auto &PS : FS.paramAccesses())
652     if (ParamNo == PS.ParamNo)
653       return &PS.Use;
654   return nullptr;
655 }
656 
657 void resolveAllCalls(UseInfo<GlobalValue> &Use,
658                      const ModuleSummaryIndex *Index) {
659   ConstantRange FullSet(Use.Range.getBitWidth(), true);
660   auto TmpCalls = std::move(Use.Calls);
661   for (const auto &C : TmpCalls) {
662     const Function *F = findCalleeInModule(C.first.Callee);
663     if (F) {
664       Use.Calls.emplace(CallInfo<GlobalValue>(F, C.first.ParamNo), C.second);
665       continue;
666     }
667 
668     if (!Index)
669       return Use.updateRange(FullSet);
670     GlobalValueSummary *GVS =
671         getGlobalValueSummary(Index, C.first.Callee->getGUID());
672 
673     FunctionSummary *FS = resolveCallee(GVS);
674     ++NumModuleCalleeLookupTotal;
675     if (!FS) {
676       ++NumModuleCalleeLookupFailed;
677       return Use.updateRange(FullSet);
678     }
679     const ConstantRange *Found = findParamAccess(*FS, C.first.ParamNo);
680     if (!Found || Found->isFullSet())
681       return Use.updateRange(FullSet);
682     ConstantRange Access = Found->sextOrTrunc(Use.Range.getBitWidth());
683     if (!Access.isEmptySet())
684       Use.updateRange(addOverflowNever(Access, C.second));
685   }
686 }
687 
688 GVToSSI createGlobalStackSafetyInfo(
689     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions,
690     const ModuleSummaryIndex *Index) {
691   GVToSSI SSI;
692   if (Functions.empty())
693     return SSI;
694 
695   // FIXME: Simplify printing and remove copying here.
696   auto Copy = Functions;
697 
698   for (auto &FnKV : Copy)
699     for (auto &KV : FnKV.second.Params) {
700       resolveAllCalls(KV.second, Index);
701       if (KV.second.Range.isFullSet())
702         KV.second.Calls.clear();
703     }
704 
705   uint32_t PointerSize = Copy.begin()
706                              ->first->getParent()
707                              ->getDataLayout()
708                              .getMaxPointerSizeInBits();
709   StackSafetyDataFlowAnalysis<GlobalValue> SSDFA(PointerSize, std::move(Copy));
710 
711   for (auto &F : SSDFA.run()) {
712     auto FI = F.second;
713     auto &SrcF = Functions[F.first];
714     for (auto &KV : FI.Allocas) {
715       auto &A = KV.second;
716       resolveAllCalls(A, Index);
717       for (auto &C : A.Calls) {
718         A.updateRange(SSDFA.getArgumentAccessRange(C.first.Callee,
719                                                    C.first.ParamNo, C.second));
720       }
721       // FIXME: This is needed only to preserve calls in print() results.
722       A.Calls = SrcF.Allocas.find(KV.first)->second.Calls;
723     }
724     for (auto &KV : FI.Params) {
725       auto &P = KV.second;
726       P.Calls = SrcF.Params.find(KV.first)->second.Calls;
727     }
728     SSI[F.first] = std::move(FI);
729   }
730 
731   return SSI;
732 }
733 
734 } // end anonymous namespace
735 
736 StackSafetyInfo::StackSafetyInfo() = default;
737 
738 StackSafetyInfo::StackSafetyInfo(Function *F,
739                                  std::function<ScalarEvolution &()> GetSE)
740     : F(F), GetSE(GetSE) {}
741 
742 StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default;
743 
744 StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default;
745 
746 StackSafetyInfo::~StackSafetyInfo() = default;
747 
748 const StackSafetyInfo::InfoTy &StackSafetyInfo::getInfo() const {
749   if (!Info) {
750     StackSafetyLocalAnalysis SSLA(*F, GetSE());
751     Info.reset(new InfoTy{SSLA.run()});
752   }
753   return *Info;
754 }
755 
756 void StackSafetyInfo::print(raw_ostream &O) const {
757   getInfo().Info.print(O, F->getName(), dyn_cast<Function>(F));
758 }
759 
760 const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
761   if (!Info) {
762     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions;
763     for (auto &F : M->functions()) {
764       if (!F.isDeclaration()) {
765         auto FI = GetSSI(F).getInfo().Info;
766         Functions.emplace(&F, std::move(FI));
767       }
768     }
769     Info.reset(new InfoTy{
770         createGlobalStackSafetyInfo(std::move(Functions), Index), {}});
771     for (auto &FnKV : Info->Info) {
772       for (auto &KV : FnKV.second.Allocas) {
773         ++NumAllocaTotal;
774         const AllocaInst *AI = KV.first;
775         if (getStaticAllocaSizeRange(*AI).contains(KV.second.Range)) {
776           Info->SafeAllocas.insert(AI);
777           ++NumAllocaStackSafe;
778         }
779       }
780     }
781     if (StackSafetyPrint)
782       print(errs());
783   }
784   return *Info;
785 }
786 
787 std::vector<FunctionSummary::ParamAccess>
788 StackSafetyInfo::getParamAccesses() const {
789   // Implementation transforms internal representation of parameter information
790   // into FunctionSummary format.
791   std::vector<FunctionSummary::ParamAccess> ParamAccesses;
792   for (const auto &KV : getInfo().Info.Params) {
793     auto &PS = KV.second;
794     // Parameter accessed by any or unknown offset, represented as FullSet by
795     // StackSafety, is handled as the parameter for which we have no
796     // StackSafety info at all. So drop it to reduce summary size.
797     if (PS.Range.isFullSet())
798       continue;
799 
800     ParamAccesses.emplace_back(KV.first, PS.Range);
801     FunctionSummary::ParamAccess &Param = ParamAccesses.back();
802 
803     Param.Calls.reserve(PS.Calls.size());
804     for (auto &C : PS.Calls) {
805       // Parameter forwarded into another function by any or unknown offset
806       // will make ParamAccess::Range as FullSet anyway. So we can drop the
807       // entire parameter like we did above.
808       // TODO(vitalybuka): Return already filtered parameters from getInfo().
809       if (C.second.isFullSet()) {
810         ParamAccesses.pop_back();
811         break;
812       }
813       Param.Calls.emplace_back(C.first.ParamNo, C.first.Callee->getGUID(),
814                                C.second);
815       llvm::sort(Param.Calls, [](const FunctionSummary::ParamAccess::Call &L,
816                                  const FunctionSummary::ParamAccess::Call &R) {
817         return std::tie(L.ParamNo, L.Callee) < std::tie(R.ParamNo, R.Callee);
818       });
819     }
820   }
821   return ParamAccesses;
822 }
823 
824 StackSafetyGlobalInfo::StackSafetyGlobalInfo() = default;
825 
826 StackSafetyGlobalInfo::StackSafetyGlobalInfo(
827     Module *M, std::function<const StackSafetyInfo &(Function &F)> GetSSI,
828     const ModuleSummaryIndex *Index)
829     : M(M), GetSSI(GetSSI), Index(Index) {
830   if (StackSafetyRun)
831     getInfo();
832 }
833 
834 StackSafetyGlobalInfo::StackSafetyGlobalInfo(StackSafetyGlobalInfo &&) =
835     default;
836 
837 StackSafetyGlobalInfo &
838 StackSafetyGlobalInfo::operator=(StackSafetyGlobalInfo &&) = default;
839 
840 StackSafetyGlobalInfo::~StackSafetyGlobalInfo() = default;
841 
842 bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
843   const auto &Info = getInfo();
844   return Info.SafeAllocas.count(&AI);
845 }
846 
847 void StackSafetyGlobalInfo::print(raw_ostream &O) const {
848   auto &SSI = getInfo().Info;
849   if (SSI.empty())
850     return;
851   const Module &M = *SSI.begin()->first->getParent();
852   for (auto &F : M.functions()) {
853     if (!F.isDeclaration()) {
854       SSI.find(&F)->second.print(O, F.getName(), &F);
855       O << "\n";
856     }
857   }
858 }
859 
// Debugging helper: prints the module-wide result to dbgs().
LLVM_DUMP_METHOD void StackSafetyGlobalInfo::dump() const { print(dbgs()); }
861 
862 AnalysisKey StackSafetyAnalysis::Key;
863 
864 StackSafetyInfo StackSafetyAnalysis::run(Function &F,
865                                          FunctionAnalysisManager &AM) {
866   return StackSafetyInfo(&F, [&AM, &F]() -> ScalarEvolution & {
867     return AM.getResult<ScalarEvolutionAnalysis>(F);
868   });
869 }
870 
871 PreservedAnalyses StackSafetyPrinterPass::run(Function &F,
872                                               FunctionAnalysisManager &AM) {
873   OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n";
874   AM.getResult<StackSafetyAnalysis>(F).print(OS);
875   return PreservedAnalyses::all();
876 }
877 
878 char StackSafetyInfoWrapperPass::ID = 0;
879 
880 StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) {
881   initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry());
882 }
883 
884 void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
885   AU.addRequiredTransitive<ScalarEvolutionWrapperPass>();
886   AU.setPreservesAll();
887 }
888 
889 void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const {
890   SSI.print(O);
891 }
892 
893 bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) {
894   auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
895   SSI = {&F, [SE]() -> ScalarEvolution & { return *SE; }};
896   return false;
897 }
898 
899 AnalysisKey StackSafetyGlobalAnalysis::Key;
900 
901 StackSafetyGlobalInfo
902 StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
903   // FIXME: Lookup Module Summary.
904   FunctionAnalysisManager &FAM =
905       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
906   return {&M,
907           [&FAM](Function &F) -> const StackSafetyInfo & {
908             return FAM.getResult<StackSafetyAnalysis>(F);
909           },
910           nullptr};
911 }
912 
913 PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M,
914                                                     ModuleAnalysisManager &AM) {
915   OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n";
916   AM.getResult<StackSafetyGlobalAnalysis>(M).print(OS);
917   return PreservedAnalyses::all();
918 }
919 
920 char StackSafetyGlobalInfoWrapperPass::ID = 0;
921 
922 StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass()
923     : ModulePass(ID) {
924   initializeStackSafetyGlobalInfoWrapperPassPass(
925       *PassRegistry::getPassRegistry());
926 }
927 
928 StackSafetyGlobalInfoWrapperPass::~StackSafetyGlobalInfoWrapperPass() = default;
929 
930 void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O,
931                                              const Module *M) const {
932   SSGI.print(O);
933 }
934 
935 void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage(
936     AnalysisUsage &AU) const {
937   AU.setPreservesAll();
938   AU.addRequired<StackSafetyInfoWrapperPass>();
939 }
940 
941 bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) {
942   const ModuleSummaryIndex *ImportSummary = nullptr;
943   if (auto *IndexWrapperPass =
944           getAnalysisIfAvailable<ImmutableModuleSummaryIndexWrapperPass>())
945     ImportSummary = IndexWrapperPass->getIndex();
946 
947   SSGI = {&M,
948           [this](Function &F) -> const StackSafetyInfo & {
949             return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult();
950           },
951           ImportSummary};
952   return false;
953 }
954 
955 bool llvm::needsParamAccessSummary(const Module &M) {
956   if (StackSafetyRun)
957     return true;
958   for (auto &F : M.functions())
959     if (F.hasFnAttribute(Attribute::SanitizeMemTag))
960       return true;
961   return false;
962 }
963 
/// Runs the stack-safety data-flow analysis over the param-access summaries in
/// the combined \p Index and rewrites those summaries in place.
///
/// Does nothing unless Index.hasParamAccess().
///
/// Every FunctionSummary with non-empty param accesses has them cleared. The
/// ones that are live and DSO-local are first converted into
/// FunctionInfo<FunctionSummary> and fed to StackSafetyDataFlowAnalysis. Their
/// results are then written back, keeping only the per-parameter use range:
/// call lists are not copied back, and full-set ranges are omitted.
///
/// When statistics are enabled, the total number of param accesses is
/// recorded before and after.
void llvm::generateParamAccessSummary(ModuleSummaryIndex &Index) {
  if (!Index.hasParamAccess())
    return;
  // Conservative "may access anything" range at the summary's bit width.
  const ConstantRange FullSet(FunctionSummary::ParamAccess::RangeWidth, true);

  // Adds the number of param accesses over all function summaries to Stat.
  // Skipped entirely when statistics are disabled to avoid the index walk.
  auto CountParamAccesses = [&](auto &Stat) {
    if (!AreStatisticsEnabled())
      return;
    for (auto &GVS : Index)
      for (auto &GV : GVS.second.SummaryList)
        if (FunctionSummary *FS = dyn_cast<FunctionSummary>(GV.get()))
          Stat += FS->paramAccesses().size();
  };

  CountParamAccesses(NumCombinedParamAccessesBefore);

  std::map<const FunctionSummary *, FunctionInfo<FunctionSummary>> Functions;

  // Convert the ModuleSummaryIndex to a FunctionMap
  for (auto &GVS : Index) {
    for (auto &GV : GVS.second.SummaryList) {
      FunctionSummary *FS = dyn_cast<FunctionSummary>(GV.get());
      if (!FS || FS->paramAccesses().empty())
        continue;
      if (FS->isLive() && FS->isDSOLocal()) {
        FunctionInfo<FunctionSummary> FI;
        for (auto &PS : FS->paramAccesses()) {
          auto &US =
              FI.Params
                  .emplace(PS.ParamNo, FunctionSummary::ParamAccess::RangeWidth)
                  .first->second;
          US.Range = PS.Use;
          for (auto &Call : PS.Calls) {
            assert(!Call.Offsets.isFullSet());
            // NOTE(review): resolveCallee is defined elsewhere in this file;
            // a null result is treated as "callee summary not found".
            FunctionSummary *S = resolveCallee(
                Index.findSummaryInModule(Call.Callee, FS->modulePath()));
            ++NumCombinedCalleeLookupTotal;
            if (!S) {
              // Unknown callee: be conservative. Treat the parameter as fully
              // accessed and drop its calls, including any already recorded.
              ++NumCombinedCalleeLookupFailed;
              US.Range = FullSet;
              US.Calls.clear();
              break;
            }
            US.Calls.emplace(CallInfo<FunctionSummary>(S, Call.ParamNo),
                             Call.Offsets);
          }
        }
        Functions.emplace(FS, std::move(FI));
      }
      // Reset data for all summaries. Live and DSO-local summaries get their
      // accesses set back from the data-flow results below. Anything else
      // will not be accessed by the ThinLTO backend, so clearing it saves
      // bitcode size.
      FS->setParamAccesses({});
    }
  }
  NumCombinedDataFlowNodes += Functions.size();
  StackSafetyDataFlowAnalysis<FunctionSummary> SSDFA(
      FunctionSummary::ParamAccess::RangeWidth, std::move(Functions));
  for (auto &KV : SSDFA.run()) {
    std::vector<FunctionSummary::ParamAccess> NewParams;
    NewParams.reserve(KV.second.Params.size());
    for (auto &Param : KV.second.Params) {
      // It's not needed as FullSet is processed the same as a missing value.
      if (Param.second.Range.isFullSet())
        continue;
      NewParams.emplace_back();
      FunctionSummary::ParamAccess &New = NewParams.back();
      New.ParamNo = Param.first;
      New.Use = Param.second.Range; // Only range is needed.
    }
    // The map keys are const pointers, but they were taken from the mutable
    // FunctionSummary objects in Index above, so writing back is legitimate.
    const_cast<FunctionSummary *>(KV.first)->setParamAccesses(
        std::move(NewParams));
  }

  CountParamAccesses(NumCombinedParamAccessesAfter);
}
1040 
// Legacy pass-manager registration of the function-local analysis
// (command-line name "stack-safety-local"). The trailing `false, true`
// arguments are the CFG-only and is-analysis flags. Its dependency on
// ScalarEvolution is registered alongside it.
static const char LocalPassArg[] = "stack-safety-local";
static const char LocalPassName[] = "Stack Safety Local Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                      false, true)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                    false, true)
1048 
// Legacy pass-manager registration of the module-level analysis. Its
// command-line name is DEBUG_TYPE ("stack-safety"). The trailing `false, true`
// arguments are the CFG-only and is-analysis flags. The local analysis and the
// optional summary-index pass are registered as dependencies.
static const char GlobalPassName[] = "Stack Safety Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                      GlobalPassName, false, true)
INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ImmutableModuleSummaryIndexWrapperPass)
INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                    GlobalPassName, false, true)
1056