1 //===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "llvm/Analysis/StackSafetyAnalysis.h"
12 #include "llvm/ADT/APInt.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/ModuleSummaryAnalysis.h"
17 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
18 #include "llvm/Analysis/StackLifetime.h"
19 #include "llvm/IR/ConstantRange.h"
20 #include "llvm/IR/DerivedTypes.h"
21 #include "llvm/IR/GlobalValue.h"
22 #include "llvm/IR/InstIterator.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/InitializePasses.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <memory>
32 
33 using namespace llvm;
34 
35 #define DEBUG_TYPE "stack-safety"
36 
// Allocas proven safe / examined by StackSafetyGlobalInfo::getInfo().
STATISTIC(NumAllocaStackSafe, "Number of safe allocas");
STATISTIC(NumAllocaTotal, "Number of total allocas");

STATISTIC(NumCombinedCalleeLookupTotal,
          "Number of total callee lookups on combined index.");
STATISTIC(NumCombinedCalleeLookupFailed,
          "Number of failed callee lookups on combined index.");
// Callee lookups performed by resolveAllCalls() on the module's summary index.
STATISTIC(NumModuleCalleeLookupTotal,
          "Number of total callee lookups on module index.");
STATISTIC(NumModuleCalleeLookupFailed,
          "Number of failed callee lookups on module index.");
STATISTIC(NumCombinedParamAccessesBefore,
          "Number of total param accesses before generateParamAccessSummary.");
STATISTIC(NumCombinedParamAccessesAfter,
          "Number of total param accesses after generateParamAccessSummary.");

// Once a function's parameter ranges have been updated more than this many
// times by the data-flow analysis, further updates widen them straight to the
// full set (see StackSafetyDataFlowAnalysis::updateOneNode).
static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations",
                                             cl::init(20), cl::Hidden);

// Print the module-wide results to errs() when they are first computed.
static cl::opt<bool> StackSafetyPrint("stack-safety-print", cl::init(false),
                                      cl::Hidden);

// Compute the module-wide analysis eagerly when StackSafetyGlobalInfo is
// constructed; also forces needsParamAccessSummary() to return true.
static cl::opt<bool> StackSafetyRun("stack-safety-run", cl::init(false),
                                    cl::Hidden);
61 
62 namespace {
63 
64 /// Describes use of address in as a function call argument.
65 template <typename CalleeTy> struct CallInfo {
66   /// Function being called.
67   const CalleeTy *Callee = nullptr;
68   /// Index of argument which pass address.
69   size_t ParamNo = 0;
70   // Offset range of address from base address (alloca or calling function
71   // argument).
72   // Range should never set to empty-set, that is an invalid access range
73   // that can cause empty-set to be propagated with ConstantRange::add
74   ConstantRange Offset;
75   CallInfo(const CalleeTy *Callee, size_t ParamNo, const ConstantRange &Offset)
76       : Callee(Callee), ParamNo(ParamNo), Offset(Offset) {}
77 };
78 
79 template <typename CalleeTy>
80 raw_ostream &operator<<(raw_ostream &OS, const CallInfo<CalleeTy> &P) {
81   return OS << "@" << P.Callee->getName() << "(arg" << P.ParamNo << ", "
82             << P.Offset << ")";
83 }
84 
85 /// Describe uses of address (alloca or parameter) inside of the function.
86 template <typename CalleeTy> struct UseInfo {
87   // Access range if the address (alloca or parameters).
88   // It is allowed to be empty-set when there are no known accesses.
89   ConstantRange Range;
90 
91   // List of calls which pass address as an argument.
92   SmallVector<CallInfo<CalleeTy>, 4> Calls;
93 
94   UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
95 
96   void updateRange(const ConstantRange &R) {
97     assert(!R.isUpperSignWrapped());
98     Range = Range.unionWith(R);
99     assert(!Range.isUpperSignWrapped());
100   }
101 };
102 
103 template <typename CalleeTy>
104 raw_ostream &operator<<(raw_ostream &OS, const UseInfo<CalleeTy> &U) {
105   OS << U.Range;
106   for (auto &Call : U.Calls)
107     OS << ", " << Call;
108   return OS;
109 }
110 
111 // Check if we should bailout for such ranges.
112 bool isUnsafe(const ConstantRange &R) {
113   return R.isEmptySet() || R.isFullSet() || R.isUpperSignWrapped();
114 }
115 
116 ConstantRange addOverflowNever(const ConstantRange &L, const ConstantRange &R) {
117   if (L.signedAddMayOverflow(R) !=
118       ConstantRange::OverflowResult::NeverOverflows)
119     return ConstantRange(L.getBitWidth(), true);
120   return L.add(R);
121 }
122 
123 /// Calculate the allocation size of a given alloca. Returns empty range
124 // in case of confution.
125 ConstantRange getStaticAllocaSizeRange(const AllocaInst &AI) {
126   const DataLayout &DL = AI.getModule()->getDataLayout();
127   TypeSize TS = DL.getTypeAllocSize(AI.getAllocatedType());
128   unsigned PointerSize = DL.getMaxPointerSizeInBits();
129   // Fallback to empty range for alloca size.
130   ConstantRange R = ConstantRange::getEmpty(PointerSize);
131   if (TS.isScalable())
132     return R;
133   APInt APSize(PointerSize, TS.getFixedSize(), true);
134   if (APSize.isNonPositive())
135     return R;
136   if (AI.isArrayAllocation()) {
137     const auto *C = dyn_cast<ConstantInt>(AI.getArraySize());
138     if (!C)
139       return R;
140     bool Overflow = false;
141     APInt Mul = C->getValue();
142     if (Mul.isNonPositive())
143       return R;
144     Mul = Mul.sextOrTrunc(PointerSize);
145     APSize = APSize.smul_ov(Mul, Overflow);
146     if (Overflow)
147       return R;
148   }
149   R = ConstantRange(APInt::getNullValue(PointerSize), APSize);
150   assert(!isUnsafe(R));
151   return R;
152 }
153 
154 template <typename CalleeTy> struct FunctionInfo {
155   std::map<const AllocaInst *, UseInfo<CalleeTy>> Allocas;
156   std::map<uint32_t, UseInfo<CalleeTy>> Params;
157   // TODO: describe return value as depending on one or more of its arguments.
158 
159   // StackSafetyDataFlowAnalysis counter stored here for faster access.
160   int UpdateCount = 0;
161 
162   void print(raw_ostream &O, StringRef Name, const Function *F) const {
163     // TODO: Consider different printout format after
164     // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then.
165     O << "  @" << Name << ((F && F->isDSOLocal()) ? "" : " dso_preemptable")
166       << ((F && F->isInterposable()) ? " interposable" : "") << "\n";
167 
168     O << "    args uses:\n";
169     for (auto &KV : Params) {
170       O << "      ";
171       if (F)
172         O << F->getArg(KV.first)->getName();
173       else
174         O << formatv("arg{0}", KV.first);
175       O << "[]: " << KV.second << "\n";
176     }
177 
178     O << "    allocas uses:\n";
179     if (F) {
180       for (auto &I : instructions(F)) {
181         if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
182           auto &AS = Allocas.find(AI)->second;
183           O << "      " << AI->getName() << "["
184             << getStaticAllocaSizeRange(*AI).getUpper() << "]: " << AS << "\n";
185         }
186       }
187     } else {
188       assert(Allocas.empty());
189     }
190     O << "\n";
191   }
192 };
193 
// Per-function stack safety results, keyed by the function's GlobalValue.
using GVToSSI = std::map<const GlobalValue *, FunctionInfo<GlobalValue>>;
195 
196 } // namespace
197 
// Local (single-function) analysis result held by StackSafetyInfo.
struct StackSafetyInfo::InfoTy {
  FunctionInfo<GlobalValue> Info;
};

// Module-wide result held by StackSafetyGlobalInfo: per-function results after
// inter-procedural propagation, plus the allocas proven safe.
struct StackSafetyGlobalInfo::InfoTy {
  GVToSSI Info;
  SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
};
206 
207 namespace {
208 
/// Intra-procedural part of the analysis. For one function, it computes the
/// byte ranges through which each alloca and each non-byval pointer argument
/// is accessed, and the calls that receive those addresses as arguments.
class StackSafetyLocalAnalysis {
  Function &F;
  const DataLayout &DL;
  ScalarEvolution &SE;
  // Bit width of every range produced here: the address-space-0 pointer size.
  // NOTE: must stay declared before UnknownRange, which the constructor
  // initializes from it.
  unsigned PointerSize = 0;

  // Full-set range of PointerSize bits: "any offset may be accessed".
  const ConstantRange UnknownRange;

  ConstantRange offsetFrom(Value *Addr, Value *Base);
  ConstantRange getAccessRange(Value *Addr, Value *Base,
                               const ConstantRange &SizeRange);
  ConstantRange getAccessRange(Value *Addr, Value *Base, TypeSize Size);
  ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U,
                                           Value *Base);

  bool analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS,
                      const StackLifetime &SL);

public:
  StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
      : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
        PointerSize(DL.getPointerSizeInBits()),
        UnknownRange(PointerSize, true) {}

  // Runs the analysis on the associated function.
  FunctionInfo<GlobalValue> run();
};
236 
237 ConstantRange StackSafetyLocalAnalysis::offsetFrom(Value *Addr, Value *Base) {
238   if (!SE.isSCEVable(Addr->getType()) || !SE.isSCEVable(Base->getType()))
239     return UnknownRange;
240 
241   auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
242   const SCEV *AddrExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Addr), PtrTy);
243   const SCEV *BaseExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Base), PtrTy);
244   const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
245 
246   ConstantRange Offset = SE.getSignedRange(Diff);
247   if (isUnsafe(Offset))
248     return UnknownRange;
249   return Offset.sextOrTrunc(PointerSize);
250 }
251 
252 ConstantRange
253 StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
254                                          const ConstantRange &SizeRange) {
255   // Zero-size loads and stores do not access memory.
256   if (SizeRange.isEmptySet())
257     return ConstantRange::getEmpty(PointerSize);
258   assert(!isUnsafe(SizeRange));
259 
260   ConstantRange Offsets = offsetFrom(Addr, Base);
261   if (isUnsafe(Offsets))
262     return UnknownRange;
263 
264   Offsets = addOverflowNever(Offsets, SizeRange);
265   if (isUnsafe(Offsets))
266     return UnknownRange;
267   return Offsets;
268 }
269 
270 ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
271                                                        TypeSize Size) {
272   if (Size.isScalable())
273     return UnknownRange;
274   APInt APSize(PointerSize, Size.getFixedSize(), true);
275   if (APSize.isNegative())
276     return UnknownRange;
277   return getAccessRange(
278       Addr, Base, ConstantRange(APInt::getNullValue(PointerSize), APSize));
279 }
280 
281 ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
282     const MemIntrinsic *MI, const Use &U, Value *Base) {
283   if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
284     if (MTI->getRawSource() != U && MTI->getRawDest() != U)
285       return ConstantRange::getEmpty(PointerSize);
286   } else {
287     if (MI->getRawDest() != U)
288       return ConstantRange::getEmpty(PointerSize);
289   }
290 
291   auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
292   if (!SE.isSCEVable(MI->getLength()->getType()))
293     return UnknownRange;
294 
295   const SCEV *Expr =
296       SE.getTruncateOrZeroExtend(SE.getSCEV(MI->getLength()), CalculationTy);
297   ConstantRange Sizes = SE.getSignedRange(Expr);
298   if (Sizes.getUpper().isNegative() || isUnsafe(Sizes))
299     return UnknownRange;
300   Sizes = Sizes.sextOrTrunc(PointerSize);
301   ConstantRange SizeRange(APInt::getNullValue(PointerSize),
302                           Sizes.getUpper() - 1);
303   return getAccessRange(U, Base, SizeRange);
304 }
305 
/// The function analyzes all local uses of Ptr (alloca or argument) and
/// calculates local access range and all function calls where it was used.
///
/// Uses are followed transitively through any instruction not handled
/// explicitly below (the `default` case, e.g. bitcasts, GEPs, PHIs). Loads,
/// stores, byval arguments and memory intrinsics contribute an access range
/// to US.Range. Calls that receive the address are recorded in US.Calls for
/// the inter-procedural phase.
///
/// Returns false, after widening US.Range to the full set, as soon as a use
/// is found that cannot be modeled. Such uses are: the address is stored or
/// returned, it is a non-argument call operand, the call is indirect, or the
/// alloca is not known to be alive. Otherwise returns true.
bool StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
                                              UseInfo<GlobalValue> &US,
                                              const StackLifetime &SL) {
  SmallPtrSet<const Value *, 16> Visited;
  SmallVector<const Value *, 8> WorkList;
  WorkList.push_back(Ptr);
  // Non-null only when analyzing an alloca; lifetime checks apply only then.
  const AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);

  // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
  while (!WorkList.empty()) {
    const Value *V = WorkList.pop_back_val();
    for (const Use &UI : V->uses()) {
      const auto *I = cast<Instruction>(UI.getUser());
      // Ignore users that StackLifetime reports as unreachable.
      if (!SL.isReachable(I))
        continue;

      assert(V == UI.get());

      switch (I->getOpcode()) {
      case Instruction::Load: {
        // An access while the alloca is not known to be alive is unsafe.
        if (AI && !SL.isAliveAfter(AI, I)) {
          US.updateRange(UnknownRange);
          return false;
        }
        US.updateRange(
            getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
        break;
      }

      case Instruction::VAArg:
        // "va-arg" from a pointer is safe.
        break;
      case Instruction::Store: {
        if (V == I->getOperand(0)) {
          // Stored the pointer - conservatively assume it may be unsafe.
          US.updateRange(UnknownRange);
          return false;
        }
        if (AI && !SL.isAliveAfter(AI, I)) {
          US.updateRange(UnknownRange);
          return false;
        }
        // V is the address operand; the access size is the stored value's.
        US.updateRange(getAccessRange(
            UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
        break;
      }

      case Instruction::Ret:
        // Information leak.
        // FIXME: Process parameters correctly. This is a leak only if we return
        // alloca.
        US.updateRange(UnknownRange);
        return false;

      case Instruction::Call:
      case Instruction::Invoke: {
        // Lifetime markers do not access memory.
        if (I->isLifetimeStartOrEnd())
          break;

        if (AI && !SL.isAliveAfter(AI, I)) {
          US.updateRange(UnknownRange);
          return false;
        }

        if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
          US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr));
          break;
        }

        // The address is used by the call as something other than an
        // argument (e.g. the called operand); give up.
        const auto &CB = cast<CallBase>(*I);
        if (!CB.isArgOperand(&UI)) {
          US.updateRange(UnknownRange);
          return false;
        }

        // A byval argument accesses the store size of its byval type at the
        // address.
        unsigned ArgNo = CB.getArgOperandNo(&UI);
        if (CB.isByValArgument(ArgNo)) {
          US.updateRange(getAccessRange(
              UI, Ptr, DL.getTypeStoreSize(CB.getParamByValType(ArgNo))));
          break;
        }

        // FIXME: consult devirt?
        // Do not follow aliases, otherwise we could inadvertently follow
        // dso_preemptable aliases or aliases with interposable linkage.
        const GlobalValue *Callee =
            dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
        if (!Callee) {
          US.updateRange(UnknownRange);
          return false;
        }

        // Record the call; the callee's parameter access is applied later,
        // during the inter-procedural phase.
        assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee));
        US.Calls.emplace_back(Callee, ArgNo, offsetFrom(UI, Ptr));
        break;
      }

      default:
        // Any other user derives a new value from the address; follow its
        // uses as well (each instruction at most once).
        if (Visited.insert(I).second)
          WorkList.push_back(cast<const Instruction>(I));
      }
    }
  }

  return true;
}
414 
415 FunctionInfo<GlobalValue> StackSafetyLocalAnalysis::run() {
416   FunctionInfo<GlobalValue> Info;
417   assert(!F.isDeclaration() &&
418          "Can't run StackSafety on a function declaration");
419 
420   LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n");
421 
422   SmallVector<AllocaInst *, 64> Allocas;
423   for (auto &I : instructions(F))
424     if (auto *AI = dyn_cast<AllocaInst>(&I))
425       Allocas.push_back(AI);
426   StackLifetime SL(F, Allocas, StackLifetime::LivenessType::Must);
427   SL.run();
428 
429   for (auto *AI : Allocas) {
430     auto &UI = Info.Allocas.emplace(AI, PointerSize).first->second;
431     analyzeAllUses(AI, UI, SL);
432   }
433 
434   for (Argument &A : make_range(F.arg_begin(), F.arg_end())) {
435     // Non pointers and bypass arguments are not going to be used in any global
436     // processing.
437     if (A.getType()->isPointerTy() && !A.hasByValAttr()) {
438       auto &UI = Info.Params.emplace(A.getArgNo(), PointerSize).first->second;
439       analyzeAllUses(&A, UI, SL);
440     }
441   }
442 
443   LLVM_DEBUG(Info.print(dbgs(), F.getName(), &F));
444   LLVM_DEBUG(dbgs() << "[StackSafety] done\n");
445   return Info;
446 }
447 
/// Inter-procedural fixed-point propagation of parameter access ranges.
///
/// Starting from the local results in Functions, each parameter's Range is
/// repeatedly widened by the access ranges of the callee parameters it is
/// forwarded to, until nothing changes. Once a function has been updated
/// more than -stack-safety-max-iterations times, its changing ranges are
/// widened straight to the full set, which guarantees termination.
template <typename CalleeTy> class StackSafetyDataFlowAnalysis {
  using FunctionMap = std::map<const CalleeTy *, FunctionInfo<CalleeTy>>;

  FunctionMap Functions;
  // Full-set range of the pointer bit width; used for unknown callees and
  // parameters.
  const ConstantRange UnknownRange;

  // Callee-to-Caller multimap.
  DenseMap<const CalleeTy *, SmallVector<const CalleeTy *, 4>> Callers;
  // Functions whose parameter ranges must be recomputed.
  SetVector<const CalleeTy *> WorkList;

  bool updateOneUse(UseInfo<CalleeTy> &US, bool UpdateToFullSet);
  void updateOneNode(const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS);
  void updateOneNode(const CalleeTy *Callee) {
    updateOneNode(Callee, Functions.find(Callee)->second);
  }
  void updateAllNodes() {
    for (auto &F : Functions)
      updateOneNode(F.first, F.second);
  }
  void runDataFlow();
#ifndef NDEBUG
  void verifyFixedPoint();
#endif

public:
  StackSafetyDataFlowAnalysis(uint32_t PointerBitWidth, FunctionMap Functions)
      : Functions(std::move(Functions)),
        UnknownRange(ConstantRange::getFull(PointerBitWidth)) {}

  // Runs the propagation and returns the resulting per-function information.
  const FunctionMap &run();

  // Range accessed through parameter ParamNo of Callee when it is passed an
  // address at Offsets from the base.
  ConstantRange getArgumentAccessRange(const CalleeTy *Callee, unsigned ParamNo,
                                       const ConstantRange &Offsets) const;
};
482 
483 template <typename CalleeTy>
484 ConstantRange StackSafetyDataFlowAnalysis<CalleeTy>::getArgumentAccessRange(
485     const CalleeTy *Callee, unsigned ParamNo,
486     const ConstantRange &Offsets) const {
487   auto FnIt = Functions.find(Callee);
488   // Unknown callee (outside of LTO domain or an indirect call).
489   if (FnIt == Functions.end())
490     return UnknownRange;
491   auto &FS = FnIt->second;
492   auto ParamIt = FS.Params.find(ParamNo);
493   if (ParamIt == FS.Params.end())
494     return UnknownRange;
495   auto &Access = ParamIt->second.Range;
496   if (Access.isEmptySet())
497     return Access;
498   if (Access.isFullSet())
499     return UnknownRange;
500   return addOverflowNever(Access, Offsets);
501 }
502 
503 template <typename CalleeTy>
504 bool StackSafetyDataFlowAnalysis<CalleeTy>::updateOneUse(UseInfo<CalleeTy> &US,
505                                                          bool UpdateToFullSet) {
506   bool Changed = false;
507   for (auto &CS : US.Calls) {
508     assert(!CS.Offset.isEmptySet() &&
509            "Param range can't be empty-set, invalid offset range");
510 
511     ConstantRange CalleeRange =
512         getArgumentAccessRange(CS.Callee, CS.ParamNo, CS.Offset);
513     if (!US.Range.contains(CalleeRange)) {
514       Changed = true;
515       if (UpdateToFullSet)
516         US.Range = UnknownRange;
517       else
518         US.Range = US.Range.unionWith(CalleeRange);
519     }
520   }
521   return Changed;
522 }
523 
524 template <typename CalleeTy>
525 void StackSafetyDataFlowAnalysis<CalleeTy>::updateOneNode(
526     const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS) {
527   bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations;
528   bool Changed = false;
529   for (auto &KV : FS.Params)
530     Changed |= updateOneUse(KV.second, UpdateToFullSet);
531 
532   if (Changed) {
533     LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount
534                       << (UpdateToFullSet ? ", full-set" : "") << "] " << &FS
535                       << "\n");
536     // Callers of this function may need updating.
537     for (auto &CallerID : Callers[Callee])
538       WorkList.insert(CallerID);
539 
540     ++FS.UpdateCount;
541   }
542 }
543 
544 template <typename CalleeTy>
545 void StackSafetyDataFlowAnalysis<CalleeTy>::runDataFlow() {
546   SmallVector<const CalleeTy *, 16> Callees;
547   for (auto &F : Functions) {
548     Callees.clear();
549     auto &FS = F.second;
550     for (auto &KV : FS.Params)
551       for (auto &CS : KV.second.Calls)
552         Callees.push_back(CS.Callee);
553 
554     llvm::sort(Callees);
555     Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end());
556 
557     for (auto &Callee : Callees)
558       Callers[Callee].push_back(F.first);
559   }
560 
561   updateAllNodes();
562 
563   while (!WorkList.empty()) {
564     const CalleeTy *Callee = WorkList.back();
565     WorkList.pop_back();
566     updateOneNode(Callee);
567   }
568 }
569 
#ifndef NDEBUG
// Debug-only convergence check: one more pass over all nodes must not enqueue
// any function.
template <typename CalleeTy>
void StackSafetyDataFlowAnalysis<CalleeTy>::verifyFixedPoint() {
  WorkList.clear();
  updateAllNodes();
  assert(WorkList.empty());
}
#endif
578 
// Propagates ranges to a fixed point (verified in debug runs) and returns the
// updated per-function information.
template <typename CalleeTy>
const typename StackSafetyDataFlowAnalysis<CalleeTy>::FunctionMap &
StackSafetyDataFlowAnalysis<CalleeTy>::run() {
  runDataFlow();
  LLVM_DEBUG(verifyFixedPoint());
  return Functions;
}
586 
587 FunctionSummary *resolveCallee(GlobalValueSummary *S) {
588   while (S) {
589     if (!S->isLive() || !S->isDSOLocal())
590       return nullptr;
591     if (FunctionSummary *FS = dyn_cast<FunctionSummary>(S))
592       return FS;
593     AliasSummary *AS = dyn_cast<AliasSummary>(S);
594     if (!AS)
595       return nullptr;
596     S = AS->getBaseObject();
597     if (S == AS)
598       return nullptr;
599   }
600   return nullptr;
601 }
602 
603 const Function *findCalleeInModule(const GlobalValue *GV) {
604   while (GV) {
605     if (GV->isDeclaration() || GV->isInterposable() || !GV->isDSOLocal())
606       return nullptr;
607     if (const Function *F = dyn_cast<Function>(GV))
608       return F;
609     const GlobalAlias *A = dyn_cast<GlobalAlias>(GV);
610     if (!A)
611       return nullptr;
612     GV = A->getBaseObject();
613     if (GV == A)
614       return nullptr;
615   }
616   return nullptr;
617 }
618 
619 GlobalValueSummary *getGlobalValueSummary(const ModuleSummaryIndex *Index,
620                                           uint64_t ValueGUID) {
621   auto VI = Index->getValueInfo(ValueGUID);
622   if (!VI || VI.getSummaryList().empty())
623     return nullptr;
624   auto &Summary = VI.getSummaryList()[0];
625   return Summary.get();
626 }
627 
628 const ConstantRange *findParamAccess(const FunctionSummary &FS,
629                                      uint32_t ParamNo) {
630   assert(FS.isLive());
631   assert(FS.isDSOLocal());
632   for (auto &PS : FS.paramAccesses())
633     if (ParamNo == PS.ParamNo)
634       return &PS.Use;
635   return nullptr;
636 }
637 
638 void resolveAllCalls(UseInfo<GlobalValue> &Use,
639                      const ModuleSummaryIndex *Index) {
640   ConstantRange FullSet(Use.Range.getBitWidth(), true);
641   for (auto &C : Use.Calls) {
642     const Function *F = findCalleeInModule(C.Callee);
643     if (F) {
644       C.Callee = F;
645       continue;
646     }
647 
648     if (!Index)
649       return Use.updateRange(FullSet);
650     GlobalValueSummary *GVS = getGlobalValueSummary(Index, C.Callee->getGUID());
651 
652     FunctionSummary *FS = resolveCallee(GVS);
653     ++NumModuleCalleeLookupTotal;
654     if (!FS) {
655       ++NumModuleCalleeLookupFailed;
656       return Use.updateRange(FullSet);
657     }
658     const ConstantRange *Found = findParamAccess(*FS, C.ParamNo);
659     if (!Found)
660       return Use.updateRange(FullSet);
661     ConstantRange Access = Found->sextOrTrunc(Use.Range.getBitWidth());
662     Use.updateRange(addOverflowNever(Access, C.Offset));
663     C.Callee = nullptr;
664   }
665 
666   Use.Calls.erase(std::remove_if(Use.Calls.begin(), Use.Calls.end(),
667                                  [](auto &T) { return !T.Callee; }),
668                   Use.Calls.end());
669 }
670 
671 GVToSSI createGlobalStackSafetyInfo(
672     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions,
673     const ModuleSummaryIndex *Index) {
674   GVToSSI SSI;
675   if (Functions.empty())
676     return SSI;
677 
678   // FIXME: Simplify printing and remove copying here.
679   auto Copy = Functions;
680 
681   for (auto &FnKV : Copy)
682     for (auto &KV : FnKV.second.Params)
683       resolveAllCalls(KV.second, Index);
684 
685   uint32_t PointerSize = Copy.begin()
686                              ->first->getParent()
687                              ->getDataLayout()
688                              .getMaxPointerSizeInBits();
689   StackSafetyDataFlowAnalysis<GlobalValue> SSDFA(PointerSize, std::move(Copy));
690 
691   for (auto &F : SSDFA.run()) {
692     auto FI = F.second;
693     auto &SrcF = Functions[F.first];
694     for (auto &KV : FI.Allocas) {
695       auto &A = KV.second;
696       resolveAllCalls(A, Index);
697       for (auto &C : A.Calls) {
698         A.updateRange(
699             SSDFA.getArgumentAccessRange(C.Callee, C.ParamNo, C.Offset));
700       }
701       // FIXME: This is needed only to preserve calls in print() results.
702       A.Calls = SrcF.Allocas.find(KV.first)->second.Calls;
703     }
704     for (auto &KV : FI.Params) {
705       auto &P = KV.second;
706       P.Calls = SrcF.Params.find(KV.first)->second.Calls;
707     }
708     SSI[F.first] = std::move(FI);
709   }
710 
711   return SSI;
712 }
713 
714 } // end anonymous namespace
715 
StackSafetyInfo::StackSafetyInfo() = default;

// Results are computed lazily by getInfo(). GetSE is called only at that
// point, to obtain ScalarEvolution for the local analysis.
StackSafetyInfo::StackSafetyInfo(Function *F,
                                 std::function<ScalarEvolution &()> GetSE)
    : F(F), GetSE(GetSE) {}

StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default;

StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default;

StackSafetyInfo::~StackSafetyInfo() = default;
727 
728 const StackSafetyInfo::InfoTy &StackSafetyInfo::getInfo() const {
729   if (!Info) {
730     StackSafetyLocalAnalysis SSLA(*F, GetSE());
731     Info.reset(new InfoTy{SSLA.run()});
732   }
733   return *Info;
734 }
735 
// Prints the local (single-function) analysis result for F.
void StackSafetyInfo::print(raw_ostream &O) const {
  getInfo().Info.print(O, F->getName(), dyn_cast<Function>(F));
}
739 
740 const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
741   if (!Info) {
742     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions;
743     for (auto &F : M->functions()) {
744       if (!F.isDeclaration()) {
745         auto FI = GetSSI(F).getInfo().Info;
746         Functions.emplace(&F, std::move(FI));
747       }
748     }
749     Info.reset(new InfoTy{
750         createGlobalStackSafetyInfo(std::move(Functions), Index), {}});
751     for (auto &FnKV : Info->Info) {
752       for (auto &KV : FnKV.second.Allocas) {
753         ++NumAllocaTotal;
754         const AllocaInst *AI = KV.first;
755         if (getStaticAllocaSizeRange(*AI).contains(KV.second.Range)) {
756           Info->SafeAllocas.insert(AI);
757           ++NumAllocaStackSafe;
758         }
759       }
760     }
761     if (StackSafetyPrint)
762       print(errs());
763   }
764   return *Info;
765 }
766 
767 std::vector<FunctionSummary::ParamAccess>
768 StackSafetyInfo::getParamAccesses() const {
769   // Implementation transforms internal representation of parameter information
770   // into FunctionSummary format.
771   std::vector<FunctionSummary::ParamAccess> ParamAccesses;
772   for (const auto &KV : getInfo().Info.Params) {
773     auto &PS = KV.second;
774     // Parameter accessed by any or unknown offset, represented as FullSet by
775     // StackSafety, is handled as the parameter for which we have no
776     // StackSafety info at all. So drop it to reduce summary size.
777     if (PS.Range.isFullSet())
778       continue;
779 
780     ParamAccesses.emplace_back(KV.first, PS.Range);
781     FunctionSummary::ParamAccess &Param = ParamAccesses.back();
782 
783     Param.Calls.reserve(PS.Calls.size());
784     for (auto &C : PS.Calls) {
785       // Parameter forwarded into another function by any or unknown offset
786       // will make ParamAccess::Range as FullSet anyway. So we can drop the
787       // entire parameter like we did above.
788       // TODO(vitalybuka): Return already filtered parameters from getInfo().
789       if (C.Offset.isFullSet()) {
790         ParamAccesses.pop_back();
791         break;
792       }
793       Param.Calls.emplace_back(C.ParamNo, C.Callee->getGUID(), C.Offset);
794     }
795   }
796   return ParamAccesses;
797 }
798 
StackSafetyGlobalInfo::StackSafetyGlobalInfo() = default;

// With -stack-safety-run the module-wide analysis is computed eagerly here.
// Otherwise it is deferred until the first getInfo() call.
StackSafetyGlobalInfo::StackSafetyGlobalInfo(
    Module *M, std::function<const StackSafetyInfo &(Function &F)> GetSSI,
    const ModuleSummaryIndex *Index)
    : M(M), GetSSI(GetSSI), Index(Index) {
  if (StackSafetyRun)
    getInfo();
}

StackSafetyGlobalInfo::StackSafetyGlobalInfo(StackSafetyGlobalInfo &&) =
    default;

StackSafetyGlobalInfo &
StackSafetyGlobalInfo::operator=(StackSafetyGlobalInfo &&) = default;

StackSafetyGlobalInfo::~StackSafetyGlobalInfo() = default;
816 
817 bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
818   const auto &Info = getInfo();
819   return Info.SafeAllocas.count(&AI);
820 }
821 
822 void StackSafetyGlobalInfo::print(raw_ostream &O) const {
823   auto &SSI = getInfo().Info;
824   if (SSI.empty())
825     return;
826   const Module &M = *SSI.begin()->first->getParent();
827   for (auto &F : M.functions()) {
828     if (!F.isDeclaration()) {
829       SSI.find(&F)->second.print(O, F.getName(), &F);
830       O << "\n";
831     }
832   }
833 }
834 
// Debugging helper: prints the module-wide result to dbgs().
LLVM_DUMP_METHOD void StackSafetyGlobalInfo::dump() const { print(dbgs()); }
836 
837 AnalysisKey StackSafetyAnalysis::Key;
838 
839 StackSafetyInfo StackSafetyAnalysis::run(Function &F,
840                                          FunctionAnalysisManager &AM) {
841   return StackSafetyInfo(&F, [&AM, &F]() -> ScalarEvolution & {
842     return AM.getResult<ScalarEvolutionAnalysis>(F);
843   });
844 }
845 
// Prints the local stack safety result for F; modifies nothing.
PreservedAnalyses StackSafetyPrinterPass::run(Function &F,
                                              FunctionAnalysisManager &AM) {
  OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n";
  AM.getResult<StackSafetyAnalysis>(F).print(OS);
  return PreservedAnalyses::all();
}
852 
char StackSafetyInfoWrapperPass::ID = 0;

StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) {
  initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry());
}

void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  // Transitive: SSI keeps a pointer to ScalarEvolution (see runOnFunction)
  // and uses it lazily, after this pass has run.
  AU.addRequiredTransitive<ScalarEvolutionWrapperPass>();
  AU.setPreservesAll();
}

void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const {
  SSI.print(O);
}
867 
868 bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) {
869   auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
870   SSI = {&F, [SE]() -> ScalarEvolution & { return *SE; }};
871   return false;
872 }
873 
874 AnalysisKey StackSafetyGlobalAnalysis::Key;
875 
876 StackSafetyGlobalInfo
877 StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
878   // FIXME: Lookup Module Summary.
879   FunctionAnalysisManager &FAM =
880       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
881   return {&M,
882           [&FAM](Function &F) -> const StackSafetyInfo & {
883             return FAM.getResult<StackSafetyAnalysis>(F);
884           },
885           nullptr};
886 }
887 
888 PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M,
889                                                     ModuleAnalysisManager &AM) {
890   OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n";
891   AM.getResult<StackSafetyGlobalAnalysis>(M).print(OS);
892   return PreservedAnalyses::all();
893 }
894 
// Legacy pass manager ID for the module-level analysis; handed to
// ModulePass(ID) in the constructor below.
char StackSafetyGlobalInfoWrapperPass::ID = 0;
896 
897 StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass()
898     : ModulePass(ID) {
899   initializeStackSafetyGlobalInfoWrapperPassPass(
900       *PassRegistry::getPassRegistry());
901 }
902 
// Defaulted out of line in this file rather than in the header; presumably so
// that member types only fully defined here can be destroyed (confirm against
// the header).
StackSafetyGlobalInfoWrapperPass::~StackSafetyGlobalInfoWrapperPass() = default;
904 
905 void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O,
906                                              const Module *M) const {
907   SSGI.print(O);
908 }
909 
910 void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage(
911     AnalysisUsage &AU) const {
912   AU.setPreservesAll();
913   AU.addRequired<StackSafetyInfoWrapperPass>();
914 }
915 
916 bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) {
917   const ModuleSummaryIndex *ImportSummary = nullptr;
918   if (auto *IndexWrapperPass =
919           getAnalysisIfAvailable<ImmutableModuleSummaryIndexWrapperPass>())
920     ImportSummary = IndexWrapperPass->getIndex();
921 
922   SSGI = {&M,
923           [this](Function &F) -> const StackSafetyInfo & {
924             return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult();
925           },
926           ImportSummary};
927   return false;
928 }
929 
930 bool llvm::needsParamAccessSummary(const Module &M) {
931   if (StackSafetyRun)
932     return true;
933   for (auto &F : M.functions())
934     if (F.hasFnAttribute(Attribute::SanitizeMemTag))
935       return true;
936   return false;
937 }
938 
// Runs the stack-safety data flow over a combined summary index and rewrites
// the parameter access lists of its function summaries in place.
//
// For every FunctionSummary with a non-empty paramAccesses() list:
//  * if it is live and DSO-local, it becomes a node of the data flow. Each
//    recorded call is resolved against the index. If any callee of a
//    parameter cannot be resolved, that parameter's range is widened to the
//    full set and its calls are dropped;
//  * its parameter accesses are then cleared. Only the live, DSO-local ones
//    get accesses back, taken from the data flow result and carrying ranges
//    only (no calls).
// Returns immediately when the index carries no parameter access data.
void llvm::generateParamAccessSummary(ModuleSummaryIndex &Index) {
  if (!Index.hasParamAccess())
    return;
  // Conservative "may access anything" range at the summary's bit width.
  const ConstantRange FullSet(FunctionSummary::ParamAccess::RangeWidth, true);

  // Adds the total number of param access entries across all function
  // summaries to Stat. The index walk is skipped when statistics are off.
  auto CountParamAccesses = [&](auto &Stat) {
    if (!AreStatisticsEnabled())
      return;
    for (auto &GVS : Index)
      for (auto &GV : GVS.second.SummaryList)
        if (FunctionSummary *FS = dyn_cast<FunctionSummary>(GV.get()))
          Stat += FS->paramAccesses().size();
  };

  CountParamAccesses(NumCombinedParamAccessesBefore);

  std::map<const FunctionSummary *, FunctionInfo<FunctionSummary>> Functions;

  // Convert the ModuleSummaryIndex to a FunctionMap
  for (auto &GVS : Index) {
    for (auto &GV : GVS.second.SummaryList) {
      FunctionSummary *FS = dyn_cast<FunctionSummary>(GV.get());
      if (!FS || FS->paramAccesses().empty())
        continue;
      // Only live, DSO-local summaries take part in the data flow.
      if (FS->isLive() && FS->isDSOLocal()) {
        FunctionInfo<FunctionSummary> FI;
        for (auto &PS : FS->paramAccesses()) {
          auto &US =
              FI.Params
                  .emplace(PS.ParamNo, FunctionSummary::ParamAccess::RangeWidth)
                  .first->second;
          // Start from the parameter's own (local) access range.
          US.Range = PS.Use;
          for (auto &Call : PS.Calls) {
            assert(!Call.Offsets.isFullSet());
            FunctionSummary *S = resolveCallee(
                Index.findSummaryInModule(Call.Callee, FS->modulePath()));
            ++NumCombinedCalleeLookupTotal;
            if (!S) {
              // Callee unknown: assume the parameter may be accessed anywhere
              // and stop tracking its calls; the other calls no longer matter.
              ++NumCombinedCalleeLookupFailed;
              US.Range = FullSet;
              US.Calls.clear();
              break;
            }
            US.Calls.emplace_back(S, Call.ParamNo, Call.Offsets);
          }
        }
        Functions.emplace(FS, std::move(FI));
      }
      // Reset data for all summaries. Alive and DSO local ones will be set
      // back from the data flow results below. Anything else will not be
      // accessed by ThinLTO backend, so we can save on bitcode size.
      FS->setParamAccesses({});
    }
  }
  // Solve the inter-procedural data flow over the collected functions.
  StackSafetyDataFlowAnalysis<FunctionSummary> SSDFA(
      FunctionSummary::ParamAccess::RangeWidth, std::move(Functions));
  for (auto &KV : SSDFA.run()) {
    std::vector<FunctionSummary::ParamAccess> NewParams;
    NewParams.reserve(KV.second.Params.size());
    for (auto &Param : KV.second.Params) {
      NewParams.emplace_back();
      FunctionSummary::ParamAccess &New = NewParams.back();
      New.ParamNo = Param.first;
      New.Use = Param.second.Range; // Only range is needed.
    }
    // Map keys are const pointers, but the summaries are owned by the
    // (non-const) Index, so writing the result back is legitimate.
    const_cast<FunctionSummary *>(KV.first)->setParamAccesses(
        std::move(NewParams));
  }

  CountParamAccesses(NumCombinedParamAccessesAfter);
}
1010 
// Legacy pass manager registration for the per-function analysis: pass
// argument "stack-safety-local", depending on ScalarEvolutionWrapperPass.
static const char LocalPassArg[] = "stack-safety-local";
static const char LocalPassName[] = "Stack Safety Local Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                      false, true)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                    false, true)

// Legacy pass manager registration for the module-level analysis, using
// DEBUG_TYPE ("stack-safety") as its pass argument. It depends on the local
// analysis and on ImmutableModuleSummaryIndexWrapperPass; runOnModule only
// queries the latter when it is available.
static const char GlobalPassName[] = "Stack Safety Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                      GlobalPassName, false, true)
INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ImmutableModuleSummaryIndexWrapperPass)
INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                    GlobalPassName, false, true)
1026