1 //===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "llvm/Analysis/StackSafetyAnalysis.h"
12 #include "llvm/ADT/APInt.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/Statistic.h"
15 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
16 #include "llvm/IR/ConstantRange.h"
17 #include "llvm/IR/DerivedTypes.h"
18 #include "llvm/IR/GlobalValue.h"
19 #include "llvm/IR/InstIterator.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/IntrinsicInst.h"
22 #include "llvm/InitializePasses.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include <algorithm>
28 #include <memory>
29 
30 using namespace llvm;
31 
32 #define DEBUG_TYPE "stack-safety"
33 
// Statistics reported under -stats: how many allocas were examined by the
// global analysis and how many of them were proven safe.
STATISTIC(NumAllocaStackSafe, "Number of safe allocas");
STATISTIC(NumAllocaTotal, "Number of total allocas");

// Per-function update budget for the interprocedural data flow. Once a
// function's parameter ranges have changed more than this many times, further
// changes widen them straight to the full set so the iteration terminates.
static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations",
                                             cl::init(20), cl::Hidden);

// When set, the global result is printed to stderr right after it is
// computed.
static cl::opt<bool> StackSafetyPrint("stack-safety-print", cl::init(false),
                                      cl::Hidden);

// When set, the global result is computed eagerly when StackSafetyGlobalInfo
// is constructed instead of on the first query.
static cl::opt<bool> StackSafetyRun("stack-safety-run", cl::init(false),
                                    cl::Hidden);
45 
46 namespace {
47 
48 /// Describes use of address in as a function call argument.
49 template <typename CalleeTy> struct CallInfo {
50   /// Function being called.
51   const CalleeTy *Callee = nullptr;
52   /// Index of argument which pass address.
53   size_t ParamNo = 0;
54   // Offset range of address from base address (alloca or calling function
55   // argument).
56   // Range should never set to empty-set, that is an invalid access range
57   // that can cause empty-set to be propagated with ConstantRange::add
58   ConstantRange Offset;
59   CallInfo(const CalleeTy *Callee, size_t ParamNo, ConstantRange Offset)
60       : Callee(Callee), ParamNo(ParamNo), Offset(Offset) {}
61 };
62 
63 template <typename CalleeTy>
64 raw_ostream &operator<<(raw_ostream &OS, const CallInfo<CalleeTy> &P) {
65   return OS << "@" << P.Callee->getName() << "(arg" << P.ParamNo << ", "
66             << P.Offset << ")";
67 }
68 
69 /// Describe uses of address (alloca or parameter) inside of the function.
70 template <typename CalleeTy> struct UseInfo {
71   // Access range if the address (alloca or parameters).
72   // It is allowed to be empty-set when there are no known accesses.
73   ConstantRange Range;
74 
75   // List of calls which pass address as an argument.
76   SmallVector<CallInfo<CalleeTy>, 4> Calls;
77 
78   UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
79 
80   void updateRange(const ConstantRange &R) {
81     assert(!R.isUpperSignWrapped());
82     Range = Range.unionWith(R);
83     assert(!Range.isUpperSignWrapped());
84   }
85 };
86 
87 template <typename CalleeTy>
88 raw_ostream &operator<<(raw_ostream &OS, const UseInfo<CalleeTy> &U) {
89   OS << U.Range;
90   for (auto &Call : U.Calls)
91     OS << ", " << Call;
92   return OS;
93 }
94 
95 // Check if we should bailout for such ranges.
96 bool isUnsafe(const ConstantRange &R) {
97   return R.isEmptySet() || R.isFullSet() || R.isUpperSignWrapped();
98 }
99 
100 /// Calculate the allocation size of a given alloca. Returns empty range
101 // in case of confution.
102 ConstantRange getStaticAllocaSizeRange(const AllocaInst &AI) {
103   const DataLayout &DL = AI.getModule()->getDataLayout();
104   TypeSize TS = DL.getTypeAllocSize(AI.getAllocatedType());
105   unsigned PointerSize = DL.getMaxPointerSizeInBits();
106   // Fallback to empty range for alloca size.
107   ConstantRange R = ConstantRange::getEmpty(PointerSize);
108   if (TS.isScalable())
109     return R;
110   APInt APSize(PointerSize, TS.getFixedSize(), true);
111   if (APSize.isNonPositive())
112     return R;
113   if (AI.isArrayAllocation()) {
114     const auto *C = dyn_cast<ConstantInt>(AI.getArraySize());
115     if (!C)
116       return R;
117     bool Overflow = false;
118     APInt Mul = C->getValue();
119     if (Mul.isNonPositive())
120       return R;
121     Mul = Mul.sextOrTrunc(PointerSize);
122     APSize = APSize.smul_ov(Mul, Overflow);
123     if (Overflow)
124       return R;
125   }
126   R = ConstantRange(APInt::getNullValue(PointerSize), APSize);
127   assert(!isUnsafe(R));
128   return R;
129 }
130 
131 template <typename CalleeTy> struct FunctionInfo {
132   std::map<const AllocaInst *, UseInfo<CalleeTy>> Allocas;
133   std::map<uint32_t, UseInfo<CalleeTy>> Params;
134   // TODO: describe return value as depending on one or more of its arguments.
135 
136   // StackSafetyDataFlowAnalysis counter stored here for faster access.
137   int UpdateCount = 0;
138 
139   void print(raw_ostream &O, StringRef Name, const Function *F) const {
140     // TODO: Consider different printout format after
141     // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then.
142     O << "  @" << Name << ((F && F->isDSOLocal()) ? "" : " dso_preemptable")
143       << ((F && F->isInterposable()) ? " interposable" : "") << "\n";
144 
145     O << "    args uses:\n";
146     for (auto &KV : Params) {
147       O << "      ";
148       if (F)
149         O << F->getArg(KV.first)->getName();
150       else
151         O << formatv("arg{0}", KV.first);
152       O << "[]: " << KV.second << "\n";
153     }
154 
155     O << "    allocas uses:\n";
156     if (F) {
157       for (auto &I : instructions(F)) {
158         if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
159           auto &AS = Allocas.find(AI)->second;
160           O << "      " << AI->getName() << "["
161             << getStaticAllocaSizeRange(*AI).getUpper() << "]: " << AS << "\n";
162         }
163       }
164     } else {
165       assert(Allocas.empty());
166     }
167   }
168 };
169 
// Per-function stack safety summaries, keyed by the function's GlobalValue.
using GVToSSI = std::map<const GlobalValue *, FunctionInfo<GlobalValue>>;
171 
172 } // namespace
173 
// Local (per-function, before interprocedural propagation) result held by
// StackSafetyInfo.
struct StackSafetyInfo::InfoTy {
  FunctionInfo<GlobalValue> Info;
};

// Module-wide result held by StackSafetyGlobalInfo: the summaries after
// interprocedural propagation and the set of allocas proven safe.
struct StackSafetyGlobalInfo::InfoTy {
  GVToSSI Info;
  SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
};
182 
183 namespace {
184 
// Computes the local summary (FunctionInfo) of one function: for each alloca
// and pointer argument, the range of bytes accessed through it inside the
// function and the calls it is passed to. Offsets are bounded with
// ScalarEvolution.
class StackSafetyLocalAnalysis {
  Function &F;
  // Declared before PointerSize, which is initialized from it.
  const DataLayout &DL;
  ScalarEvolution &SE;
  // Bit width of every range this analysis produces (the default pointer
  // size, DL.getPointerSizeInBits()).
  unsigned PointerSize = 0;

  // Full set of PointerSize bits, meaning "any byte may be accessed". It is
  // initialized from PointerSize, so it must stay declared after it.
  const ConstantRange UnknownRange;

  // Signed range of Addr - Base, or UnknownRange.
  ConstantRange offsetFrom(Value *Addr, Value *Base);
  // Range of bytes, relative to Base, touched by an access at Addr.
  ConstantRange getAccessRange(Value *Addr, Value *Base,
                               ConstantRange SizeRange);
  ConstantRange getAccessRange(Value *Addr, Value *Base, TypeSize Size);
  // Range of bytes, relative to Base, touched through the use U of MI.
  ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U,
                                           Value *Base);

  // Fills AS with the accesses and calls reachable from Ptr; returns false if
  // the analysis gave up on Ptr.
  bool analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS);

public:
  StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
      : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
        PointerSize(DL.getPointerSizeInBits()),
        UnknownRange(PointerSize, true) {}

  // Run the analysis on the associated function.
  FunctionInfo<GlobalValue> run();
};
211 
212 ConstantRange StackSafetyLocalAnalysis::offsetFrom(Value *Addr, Value *Base) {
213   if (!SE.isSCEVable(Addr->getType()) || !SE.isSCEVable(Base->getType()))
214     return UnknownRange;
215 
216   auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
217   const SCEV *AddrExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Addr), PtrTy);
218   const SCEV *BaseExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Base), PtrTy);
219   const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
220 
221   ConstantRange Offset = SE.getSignedRange(Diff);
222   if (isUnsafe(Offset))
223     return UnknownRange;
224   return Offset.sextOrTrunc(PointerSize);
225 }
226 
227 ConstantRange
228 StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
229                                          ConstantRange SizeRange) {
230   // Zero-size loads and stores do not access memory.
231   if (SizeRange.isEmptySet())
232     return ConstantRange::getEmpty(PointerSize);
233   assert(!isUnsafe(SizeRange));
234 
235   ConstantRange Offsets = offsetFrom(Addr, Base);
236   if (isUnsafe(Offsets))
237     return UnknownRange;
238 
239   if (Offsets.signedAddMayOverflow(SizeRange) !=
240       ConstantRange::OverflowResult::NeverOverflows)
241     return UnknownRange;
242   Offsets = Offsets.add(SizeRange);
243   if (isUnsafe(Offsets))
244     return UnknownRange;
245   return Offsets;
246 }
247 
248 ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
249                                                        TypeSize Size) {
250   if (Size.isScalable())
251     return UnknownRange;
252   APInt APSize(PointerSize, Size.getFixedSize(), true);
253   if (APSize.isNegative())
254     return UnknownRange;
255   return getAccessRange(
256       Addr, Base, ConstantRange(APInt::getNullValue(PointerSize), APSize));
257 }
258 
259 ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
260     const MemIntrinsic *MI, const Use &U, Value *Base) {
261   if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
262     if (MTI->getRawSource() != U && MTI->getRawDest() != U)
263       return ConstantRange::getEmpty(PointerSize);
264   } else {
265     if (MI->getRawDest() != U)
266       return ConstantRange::getEmpty(PointerSize);
267   }
268 
269   auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
270   if (!SE.isSCEVable(MI->getLength()->getType()))
271     return UnknownRange;
272 
273   const SCEV *Expr =
274       SE.getTruncateOrZeroExtend(SE.getSCEV(MI->getLength()), CalculationTy);
275   ConstantRange Sizes = SE.getSignedRange(Expr);
276   if (Sizes.getUpper().isNegative() || isUnsafe(Sizes))
277     return UnknownRange;
278   Sizes = Sizes.sextOrTrunc(PointerSize);
279   ConstantRange SizeRange(APInt::getNullValue(PointerSize),
280                           Sizes.getUpper() - 1);
281   return getAccessRange(U, Base, SizeRange);
282 }
283 
/// The function analyzes all local uses of Ptr (alloca or argument) and
/// calculates local access range and all function calls where it was used.
///
/// It walks the transitive users of Ptr:
/// - loads, stores and memory intrinsics widen US.Range with their access
///   range relative to Ptr;
/// - direct calls append a CallInfo to US.Calls;
/// - any other instruction is treated as deriving a new address, and its own
///   users are scanned as well.
///
/// It returns false, after widening US.Range to UnknownRange, as soon as the
/// address escapes in a way it does not track: stored to memory, returned,
/// passed to a call whose callee is not a GlobalValue, or used by a call other
/// than as an argument. Otherwise it returns true.
bool StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
                                              UseInfo<GlobalValue> &US) {
  // Users already queued, so cycles (e.g. through PHIs) are walked only once.
  SmallPtrSet<const Value *, 16> Visited;
  SmallVector<const Value *, 8> WorkList;
  WorkList.push_back(Ptr);

  // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
  while (!WorkList.empty()) {
    const Value *V = WorkList.pop_back_val();
    for (const Use &UI : V->uses()) {
      // NOTE(review): cast<> assumes every user is an Instruction; a
      // constant-expression user would assert here.
      const auto *I = cast<const Instruction>(UI.getUser());
      assert(V == UI.get());

      switch (I->getOpcode()) {
      case Instruction::Load: {
        // Offsets are measured from Ptr, not from the derived value V.
        US.updateRange(
            getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
        break;
      }

      case Instruction::VAArg:
        // "va-arg" from a pointer is safe.
        break;
      case Instruction::Store: {
        if (V == I->getOperand(0)) {
          // Stored the pointer - conservatively assume it may be unsafe.
          US.updateRange(UnknownRange);
          return false;
        }
        US.updateRange(getAccessRange(
            UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
        break;
      }

      case Instruction::Ret:
        // Information leak.
        // FIXME: Process parameters correctly. This is a leak only if we return
        // alloca.
        US.updateRange(UnknownRange);
        return false;

      case Instruction::Call:
      case Instruction::Invoke: {
        const auto &CB = cast<CallBase>(*I);

        // Lifetime markers are not memory accesses.
        if (I->isLifetimeStartOrEnd())
          break;

        if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
          US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr));
          break;
        }

        // FIXME: consult devirt?
        // Do not follow aliases, otherwise we could inadvertently follow
        // dso_preemptable aliases or aliases with interposable linkage.
        const GlobalValue *Callee =
            dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
        if (!Callee) {
          US.updateRange(UnknownRange);
          return false;
        }

        assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee));

        // Record one CallInfo per argument slot that receives V. If V is
        // passed in several slots, each of its uses repeats this whole scan,
        // so the entries end up duplicated.
        int Found = 0;
        for (size_t ArgNo = 0; ArgNo < CB.getNumArgOperands(); ++ArgNo) {
          if (CB.getArgOperand(ArgNo) == V) {
            ++Found;
            US.Calls.emplace_back(Callee, ArgNo, offsetFrom(UI, Ptr));
          }
        }
        // V is used by the call, but not as an argument (presumably in an
        // operand bundle): give up.
        if (!Found) {
          US.updateRange(UnknownRange);
          return false;
        }

        break;
      }

      default:
        // Any other user is assumed to derive a new address from V; scan its
        // uses too.
        if (Visited.insert(I).second)
          WorkList.push_back(cast<const Instruction>(I));
      }
    }
  }

  return true;
}
375 
376 FunctionInfo<GlobalValue> StackSafetyLocalAnalysis::run() {
377   FunctionInfo<GlobalValue> Info;
378   assert(!F.isDeclaration() &&
379          "Can't run StackSafety on a function declaration");
380 
381   LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n");
382 
383   for (auto &I : instructions(F)) {
384     if (auto *AI = dyn_cast<AllocaInst>(&I)) {
385       auto &UI = Info.Allocas.emplace(AI, PointerSize).first->second;
386       analyzeAllUses(AI, UI);
387     }
388   }
389 
390   for (Argument &A : make_range(F.arg_begin(), F.arg_end())) {
391     if (A.getType()->isPointerTy()) {
392       auto &UI = Info.Params.emplace(A.getArgNo(), PointerSize).first->second;
393       analyzeAllUses(&A, UI);
394     }
395   }
396 
397   LLVM_DEBUG(Info.print(dbgs(), F.getName(), &F));
398   LLVM_DEBUG(dbgs() << "[StackSafety] done\n");
399   return Info;
400 }
401 
// Interprocedural propagation over FunctionInfo summaries. Each function's
// parameter ranges are repeatedly widened with the ranges its callees report
// for the arguments it passes, until no summary changes.
template <typename CalleeTy> class StackSafetyDataFlowAnalysis {
  using FunctionMap = std::map<const CalleeTy *, FunctionInfo<CalleeTy>>;

  // Summaries being propagated (moved in by the constructor).
  FunctionMap Functions;
  // Full set of the pointer bit width given to the constructor.
  const ConstantRange UnknownRange;

  // Callee-to-Caller multimap.
  DenseMap<const CalleeTy *, SmallVector<const CalleeTy *, 4>> Callers;
  // Functions to re-evaluate because the summary of one of their callees
  // changed.
  SetVector<const CalleeTy *> WorkList;

  // Widens one use summary from its calls; returns true if it changed.
  bool updateOneUse(UseInfo<CalleeTy> &US, bool UpdateToFullSet);
  // Re-evaluates all parameter summaries of one function.
  void updateOneNode(const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS);
  void updateOneNode(const CalleeTy *Callee) {
    updateOneNode(Callee, Functions.find(Callee)->second);
  }
  void updateAllNodes() {
    for (auto &F : Functions)
      updateOneNode(F.first, F.second);
  }
  void runDataFlow();
#ifndef NDEBUG
  void verifyFixedPoint();
#endif

public:
  StackSafetyDataFlowAnalysis(uint32_t PointerBitWidth, FunctionMap Functions)
      : Functions(std::move(Functions)),
        UnknownRange(ConstantRange::getFull(PointerBitWidth)) {}

  // Runs the propagation and returns the updated summaries.
  const FunctionMap &run();

  // Access range, relative to the caller's base, implied by passing an address
  // at Offsets as argument ParamNo of Callee.
  ConstantRange getArgumentAccessRange(const CalleeTy *Callee, unsigned ParamNo,
                                       const ConstantRange &Offsets) const;
};
436 
437 template <typename CalleeTy>
438 ConstantRange StackSafetyDataFlowAnalysis<CalleeTy>::getArgumentAccessRange(
439     const CalleeTy *Callee, unsigned ParamNo,
440     const ConstantRange &Offsets) const {
441   auto FnIt = Functions.find(Callee);
442   // Unknown callee (outside of LTO domain or an indirect call).
443   if (FnIt == Functions.end())
444     return UnknownRange;
445   auto &FS = FnIt->second;
446   auto ParamIt = FS.Params.find(ParamNo);
447   if (ParamIt == FS.Params.end())
448     return UnknownRange;
449   auto &Access = ParamIt->second.Range;
450   if (Access.isEmptySet())
451     return Access;
452   if (Access.isFullSet())
453     return UnknownRange;
454   if (Offsets.signedAddMayOverflow(Access) !=
455       ConstantRange::OverflowResult::NeverOverflows)
456     return UnknownRange;
457   return Access.add(Offsets);
458 }
459 
460 template <typename CalleeTy>
461 bool StackSafetyDataFlowAnalysis<CalleeTy>::updateOneUse(UseInfo<CalleeTy> &US,
462                                                          bool UpdateToFullSet) {
463   bool Changed = false;
464   for (auto &CS : US.Calls) {
465     assert(!CS.Offset.isEmptySet() &&
466            "Param range can't be empty-set, invalid offset range");
467 
468     ConstantRange CalleeRange =
469         getArgumentAccessRange(CS.Callee, CS.ParamNo, CS.Offset);
470     if (!US.Range.contains(CalleeRange)) {
471       Changed = true;
472       if (UpdateToFullSet)
473         US.Range = UnknownRange;
474       else
475         US.Range = US.Range.unionWith(CalleeRange);
476     }
477   }
478   return Changed;
479 }
480 
481 template <typename CalleeTy>
482 void StackSafetyDataFlowAnalysis<CalleeTy>::updateOneNode(
483     const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS) {
484   bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations;
485   bool Changed = false;
486   for (auto &KV : FS.Params)
487     Changed |= updateOneUse(KV.second, UpdateToFullSet);
488 
489   if (Changed) {
490     LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount
491                       << (UpdateToFullSet ? ", full-set" : "") << "] " << &FS
492                       << "\n");
493     // Callers of this function may need updating.
494     for (auto &CallerID : Callers[Callee])
495       WorkList.insert(CallerID);
496 
497     ++FS.UpdateCount;
498   }
499 }
500 
501 template <typename CalleeTy>
502 void StackSafetyDataFlowAnalysis<CalleeTy>::runDataFlow() {
503   SmallVector<const CalleeTy *, 16> Callees;
504   for (auto &F : Functions) {
505     Callees.clear();
506     auto &FS = F.second;
507     for (auto &KV : FS.Params)
508       for (auto &CS : KV.second.Calls)
509         Callees.push_back(CS.Callee);
510 
511     llvm::sort(Callees);
512     Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end());
513 
514     for (auto &Callee : Callees)
515       Callers[Callee].push_back(F.first);
516   }
517 
518   updateAllNodes();
519 
520   while (!WorkList.empty()) {
521     const CalleeTy *Callee = WorkList.back();
522     WorkList.pop_back();
523     updateOneNode(Callee);
524   }
525 }
526 
#ifndef NDEBUG
// Debug-only convergence check: one more pass over all functions must not
// change any summary, i.e. must not queue any function.
template <typename CalleeTy>
void StackSafetyDataFlowAnalysis<CalleeTy>::verifyFixedPoint() {
  WorkList.clear();
  updateAllNodes();
  assert(WorkList.empty());
}
#endif
535 
// Propagates to a fixed point (checked in debug builds when debug output is
// enabled) and returns the updated summaries.
template <typename CalleeTy>
const typename StackSafetyDataFlowAnalysis<CalleeTy>::FunctionMap &
StackSafetyDataFlowAnalysis<CalleeTy>::run() {
  runDataFlow();
  LLVM_DEBUG(verifyFixedPoint());
  return Functions;
}
543 
544 const Function *findCalleeInModule(const GlobalValue *GV) {
545   while (GV) {
546     if (GV->isDeclaration() || GV->isInterposable() || !GV->isDSOLocal())
547       return nullptr;
548     if (const Function *F = dyn_cast<Function>(GV))
549       return F;
550     const GlobalAlias *A = dyn_cast<GlobalAlias>(GV);
551     if (!A)
552       return nullptr;
553     GV = A->getBaseObject();
554     if (GV == A)
555       return nullptr;
556   }
557   return nullptr;
558 }
559 
560 template <typename CalleeTy> void resolveAllCalls(UseInfo<CalleeTy> &Use) {
561   ConstantRange FullSet(Use.Range.getBitWidth(), true);
562   for (auto &C : Use.Calls) {
563     const Function *F = findCalleeInModule(C.Callee);
564     if (F) {
565       C.Callee = F;
566       continue;
567     }
568 
569     return Use.updateRange(FullSet);
570   }
571 }
572 
573 GVToSSI createGlobalStackSafetyInfo(
574     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions) {
575   GVToSSI SSI;
576   if (Functions.empty())
577     return SSI;
578 
579   // FIXME: Simplify printing and remove copying here.
580   auto Copy = Functions;
581 
582   for (auto &FnKV : Copy)
583     for (auto &KV : FnKV.second.Params)
584       resolveAllCalls(KV.second);
585 
586   uint32_t PointerSize = Copy.begin()
587                              ->first->getParent()
588                              ->getDataLayout()
589                              .getMaxPointerSizeInBits();
590   StackSafetyDataFlowAnalysis<GlobalValue> SSDFA(PointerSize, std::move(Copy));
591 
592   for (auto &F : SSDFA.run()) {
593     auto FI = F.second;
594     auto &SrcF = Functions[F.first];
595     for (auto &KV : FI.Allocas) {
596       auto &A = KV.second;
597       resolveAllCalls(A);
598       for (auto &C : A.Calls) {
599         A.updateRange(
600             SSDFA.getArgumentAccessRange(C.Callee, C.ParamNo, C.Offset));
601       }
602       // FIXME: This is needed only to preserve calls in print() results.
603       A.Calls = SrcF.Allocas.find(KV.first)->second.Calls;
604     }
605     for (auto &KV : FI.Params) {
606       auto &P = KV.second;
607       P.Calls = SrcF.Params.find(KV.first)->second.Calls;
608     }
609     SSI[F.first] = std::move(FI);
610   }
611 
612   return SSI;
613 }
614 
615 } // end anonymous namespace
616 
StackSafetyInfo::StackSafetyInfo() = default;

// Only stores F and the ScalarEvolution getter; the local analysis runs on the
// first getInfo() call.
StackSafetyInfo::StackSafetyInfo(Function *F,
                                 std::function<ScalarEvolution &()> GetSE)
    : F(F), GetSE(GetSE) {}

StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default;

StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default;

StackSafetyInfo::~StackSafetyInfo() = default;
628 
629 const StackSafetyInfo::InfoTy &StackSafetyInfo::getInfo() const {
630   if (!Info) {
631     StackSafetyLocalAnalysis SSLA(*F, GetSE());
632     Info.reset(new InfoTy{SSLA.run()});
633   }
634   return *Info;
635 }
636 
// Prints the local (pre-propagation) summary of F.
void StackSafetyInfo::print(raw_ostream &O) const {
  getInfo().Info.print(O, F->getName(), dyn_cast<Function>(F));
}
640 
641 const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
642   if (!Info) {
643     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions;
644     for (auto &F : M->functions()) {
645       if (!F.isDeclaration()) {
646         auto FI = GetSSI(F).getInfo().Info;
647         Functions.emplace(&F, std::move(FI));
648       }
649     }
650     Info.reset(
651         new InfoTy{createGlobalStackSafetyInfo(std::move(Functions)), {}});
652     for (auto &FnKV : Info->Info) {
653       for (auto &KV : FnKV.second.Allocas) {
654         ++NumAllocaTotal;
655         const AllocaInst *AI = KV.first;
656         if (getStaticAllocaSizeRange(*AI).contains(KV.second.Range)) {
657           Info->SafeAllocas.insert(AI);
658           ++NumAllocaStackSafe;
659         }
660       }
661     }
662     if (StackSafetyPrint)
663       print(errs());
664   }
665   return *Info;
666 }
667 
StackSafetyGlobalInfo::StackSafetyGlobalInfo() = default;

// Stores M and the per-function getter. The global analysis runs lazily in
// getInfo(), or immediately when -stack-safety-run is set.
StackSafetyGlobalInfo::StackSafetyGlobalInfo(
    Module *M, std::function<const StackSafetyInfo &(Function &F)> GetSSI)
    : M(M), GetSSI(GetSSI) {
  if (StackSafetyRun)
    getInfo();
}

StackSafetyGlobalInfo::StackSafetyGlobalInfo(StackSafetyGlobalInfo &&) =
    default;

StackSafetyGlobalInfo &
StackSafetyGlobalInfo::operator=(StackSafetyGlobalInfo &&) = default;

StackSafetyGlobalInfo::~StackSafetyGlobalInfo() = default;
684 
685 bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
686   const auto &Info = getInfo();
687   return Info.SafeAllocas.count(&AI);
688 }
689 
690 void StackSafetyGlobalInfo::print(raw_ostream &O) const {
691   auto &SSI = getInfo().Info;
692   if (SSI.empty())
693     return;
694   const Module &M = *SSI.begin()->first->getParent();
695   for (auto &F : M.functions()) {
696     if (!F.isDeclaration()) {
697       SSI.find(&F)->second.print(O, F.getName(), &F);
698       O << "\n";
699     }
700   }
701 }
702 
// Debugger helper: prints the global result to the debug stream.
LLVM_DUMP_METHOD void StackSafetyGlobalInfo::dump() const { print(dbgs()); }
704 
AnalysisKey StackSafetyAnalysis::Key;

// New pass manager entry point. The result only captures F and a callback
// that fetches ScalarEvolution from AM; the analysis runs on the first query.
StackSafetyInfo StackSafetyAnalysis::run(Function &F,
                                         FunctionAnalysisManager &AM) {
  return StackSafetyInfo(&F, [&AM, &F]() -> ScalarEvolution & {
    return AM.getResult<ScalarEvolutionAnalysis>(F);
  });
}
713 
// Prints the local (per-function, pre-propagation) result for F.
PreservedAnalyses StackSafetyPrinterPass::run(Function &F,
                                              FunctionAnalysisManager &AM) {
  OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n";
  AM.getResult<StackSafetyAnalysis>(F).print(OS);
  return PreservedAnalyses::all();
}
720 
char StackSafetyInfoWrapperPass::ID = 0;

StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) {
  initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry());
}

// ScalarEvolution is required transitively because the lazily computed result
// keeps a pointer to it and uses it after runOnFunction returns.
void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequiredTransitive<ScalarEvolutionWrapperPass>();
  AU.setPreservesAll();
}

void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const {
  SSI.print(O);
}

// Legacy pass manager entry point: stores a lazily evaluated result for F.
// The IR is never modified.
bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) {
  auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  SSI = {&F, [SE]() -> ScalarEvolution & { return *SE; }};
  return false;
}
741 
AnalysisKey StackSafetyGlobalAnalysis::Key;

// New pass manager entry point. Per-function results are fetched from the
// function analysis manager when the global result is first needed.
StackSafetyGlobalInfo
StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
  FunctionAnalysisManager &FAM =
      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  return {&M, [&FAM](Function &F) -> const StackSafetyInfo & {
            return FAM.getResult<StackSafetyAnalysis>(F);
          }};
}

// Prints the module-wide (post-propagation) result.
PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M,
                                                    ModuleAnalysisManager &AM) {
  OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n";
  AM.getResult<StackSafetyGlobalAnalysis>(M).print(OS);
  return PreservedAnalyses::all();
}
759 
char StackSafetyGlobalInfoWrapperPass::ID = 0;

StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass()
    : ModulePass(ID) {
  initializeStackSafetyGlobalInfoWrapperPassPass(
      *PassRegistry::getPassRegistry());
}

StackSafetyGlobalInfoWrapperPass::~StackSafetyGlobalInfoWrapperPass() = default;

void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O,
                                             const Module *M) const {
  SSGI.print(O);
}

void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage(
    AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<StackSafetyInfoWrapperPass>();
}

// Legacy pass manager entry point: stores a lazily evaluated module result
// that queries StackSafetyInfoWrapperPass per function. The IR is never
// modified.
bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) {
  SSGI = {&M, [this](Function &F) -> const StackSafetyInfo & {
            return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult();
          }};
  return false;
}
787 
// Legacy pass registration. Both passes are registered as analyses
// (is_analysis = true) that do not only look at the CFG (cfg_only = false).
static const char LocalPassArg[] = "stack-safety-local";
static const char LocalPassName[] = "Stack Safety Local Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                      false, true)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                    false, true)

static const char GlobalPassName[] = "Stack Safety Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                      GlobalPassName, false, true)
INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass)
INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                    GlobalPassName, false, true)
802