1 //===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "llvm/Analysis/StackSafetyAnalysis.h"
12 #include "llvm/ADT/APInt.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
15 #include "llvm/IR/ConstantRange.h"
16 #include "llvm/IR/DerivedTypes.h"
17 #include "llvm/IR/GlobalValue.h"
18 #include "llvm/IR/InstIterator.h"
19 #include "llvm/IR/Instructions.h"
20 #include "llvm/IR/IntrinsicInst.h"
21 #include "llvm/InitializePasses.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include <algorithm>
27 #include <memory>
28 
29 using namespace llvm;
30 
31 #define DEBUG_TYPE "stack-safety"
32 
33 static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations",
34                                              cl::init(20), cl::Hidden);
35 
36 static cl::opt<int> StackSafetyPrint("stack-safety-print", cl::init(0),
37                                      cl::Hidden);
38 
39 namespace {
40 
41 /// Rewrite an SCEV expression for a memory access address to an expression that
42 /// represents offset from the given alloca.
43 class AllocaOffsetRewriter : public SCEVRewriteVisitor<AllocaOffsetRewriter> {
44   const Value *AllocaPtr;
45 
46 public:
47   AllocaOffsetRewriter(ScalarEvolution &SE, const Value *AllocaPtr)
48       : SCEVRewriteVisitor(SE), AllocaPtr(AllocaPtr) {}
49 
50   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
51     // FIXME: look through one or several levels of definitions?
52     // This can be inttoptr(AllocaPtr) and SCEV would not unwrap
53     // it for us.
54     if (Expr->getValue() == AllocaPtr)
55       return SE.getZero(Expr->getType());
56     return Expr;
57   }
58 };
59 
60 /// Describes use of address in as a function call argument.
61 template <typename CalleeTy> struct CallInfo {
62   /// Function being called.
63   const CalleeTy *Callee = nullptr;
64   /// Index of argument which pass address.
65   size_t ParamNo = 0;
66   // Offset range of address from base address (alloca or calling function
67   // argument).
68   // Range should never set to empty-set, that is an invalid access range
69   // that can cause empty-set to be propagated with ConstantRange::add
70   ConstantRange Offset;
71   CallInfo(const CalleeTy *Callee, size_t ParamNo, ConstantRange Offset)
72       : Callee(Callee), ParamNo(ParamNo), Offset(Offset) {}
73 };
74 
75 template <typename CalleeTy>
76 raw_ostream &operator<<(raw_ostream &OS, const CallInfo<CalleeTy> &P) {
77   return OS << "@" << P.Callee->getName() << "(arg" << P.ParamNo << ", "
78             << P.Offset << ")";
79 }
80 
81 /// Describe uses of address (alloca or parameter) inside of the function.
82 template <typename CalleeTy> struct UseInfo {
83   // Access range if the address (alloca or parameters).
84   // It is allowed to be empty-set when there are no known accesses.
85   ConstantRange Range;
86 
87   // List of calls which pass address as an argument.
88   SmallVector<CallInfo<CalleeTy>, 4> Calls;
89 
90   UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
91 
92   void updateRange(const ConstantRange &R) {
93     assert(!R.isUpperSignWrapped());
94     Range = Range.unionWith(R);
95     assert(!Range.isUpperSignWrapped());
96   }
97 };
98 
99 template <typename CalleeTy>
100 raw_ostream &operator<<(raw_ostream &OS, const UseInfo<CalleeTy> &U) {
101   OS << U.Range;
102   for (auto &Call : U.Calls)
103     OS << ", " << Call;
104   return OS;
105 }
106 
107 // Check if we should bailout for such ranges.
108 bool isUnsafe(const ConstantRange &R) {
109   return R.isEmptySet() || R.isFullSet() || R.isUpperSignWrapped();
110 }
111 
112 /// Calculate the allocation size of a given alloca. Returns empty range
113 // in case of confution.
114 ConstantRange getStaticAllocaSizeRange(const AllocaInst &AI) {
115   const DataLayout &DL = AI.getModule()->getDataLayout();
116   TypeSize TS = DL.getTypeAllocSize(AI.getAllocatedType());
117   unsigned PointerSize = DL.getMaxPointerSizeInBits();
118   // Fallback to empty range for alloca size.
119   ConstantRange R = ConstantRange::getEmpty(PointerSize);
120   if (TS.isScalable())
121     return R;
122   APInt APSize(PointerSize, TS.getFixedSize(), true);
123   if (APSize.isNonPositive())
124     return R;
125   if (AI.isArrayAllocation()) {
126     const auto *C = dyn_cast<ConstantInt>(AI.getArraySize());
127     if (!C)
128       return R;
129     bool Overflow = false;
130     APInt Mul = C->getValue();
131     if (Mul.isNonPositive())
132       return R;
133     Mul = Mul.sextOrTrunc(PointerSize);
134     APSize = APSize.smul_ov(Mul, Overflow);
135     if (Overflow)
136       return R;
137   }
138   R = ConstantRange(APInt::getNullValue(PointerSize), APSize);
139   assert(!isUnsafe(R));
140   return R;
141 }
142 
143 template <typename CalleeTy> struct FunctionInfo {
144   std::map<const AllocaInst *, UseInfo<CalleeTy>> Allocas;
145   std::map<uint32_t, UseInfo<CalleeTy>> Params;
146   // TODO: describe return value as depending on one or more of its arguments.
147 
148   // StackSafetyDataFlowAnalysis counter stored here for faster access.
149   int UpdateCount = 0;
150 
151   void print(raw_ostream &O, StringRef Name, const Function *F) const {
152     // TODO: Consider different printout format after
153     // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then.
154     O << "  @" << Name << ((F && F->isDSOLocal()) ? "" : " dso_preemptable")
155       << ((F && F->isInterposable()) ? " interposable" : "") << "\n";
156 
157     O << "    args uses:\n";
158     for (auto &KV : Params) {
159       O << "      ";
160       if (F)
161         O << F->getArg(KV.first)->getName();
162       else
163         O << formatv("arg{0}", KV.first);
164       O << "[]: " << KV.second << "\n";
165     }
166 
167     O << "    allocas uses:\n";
168     if (F) {
169       for (auto &I : instructions(F)) {
170         if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
171           auto &AS = Allocas.find(AI)->second;
172           O << "      " << AI->getName() << "["
173             << getStaticAllocaSizeRange(*AI).getUpper() << "]: " << AS << "\n";
174         }
175       }
176     } else {
177       assert(Allocas.empty());
178     }
179   }
180 };
181 
182 using GVToSSI = std::map<const GlobalValue *, FunctionInfo<GlobalValue>>;
183 
184 } // namespace
185 
/// Storage behind StackSafetyInfo's lazily computed result: the local
/// summary of one function.
struct StackSafetyInfo::InfoTy {
  FunctionInfo<GlobalValue> Info;
};

/// Storage behind StackSafetyGlobalInfo's lazily computed result: the
/// module-wide summaries and the allocas whose accesses stay within their
/// static size.
// NOTE: getInfo() brace-initializes these aggregates positionally; keep the
// member order.
struct StackSafetyGlobalInfo::InfoTy {
  GVToSSI Info;
  SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
};
194 
195 namespace {
196 
/// Computes the local (intra-procedural) stack safety summary of a single
/// function. For every alloca and every pointer argument it records the range
/// of accessed offsets and the calls the address is passed to.
class StackSafetyLocalAnalysis {
  Function &F;
  const DataLayout &DL;
  ScalarEvolution &SE;
  /// Bit width of every range built by this analysis.
  // NOTE(review): this is the default address space's pointer width. By
  // contrast, getStaticAllocaSizeRange() and createGlobalStackSafetyInfo()
  // use DL.getMaxPointerSizeInBits(). On targets where these differ the
  // ranges would get mismatched bit widths — confirm this is intended.
  unsigned PointerSize = 0;

  /// Full-set range used whenever an access cannot be bounded. It is
  /// initialized from PointerSize, so it must stay declared after it.
  const ConstantRange UnknownRange;

  /// Range of offsets of Addr from Base, or UnknownRange.
  ConstantRange offsetFrom(Value *Addr, Value *Base);
  /// Range of offsets, relative to Base, touched by an access at Addr whose
  /// own byte positions are SizeRange.
  ConstantRange getAccessRange(Value *Addr, Value *Base,
                               ConstantRange SizeRange);
  ConstantRange getAccessRange(Value *Addr, Value *Base, TypeSize Size);
  /// Range touched by a memory intrinsic through its use U.
  ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U,
                                           Value *Base);

  /// Collects the uses of Ptr into AS. Returns false if it stopped early
  /// after finding a use that makes the range unknown.
  bool analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS);

public:
  StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
      : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
        PointerSize(DL.getPointerSizeInBits()),
        UnknownRange(PointerSize, true) {}

  // Run the analysis on the associated function and return its summary.
  FunctionInfo<GlobalValue> run();
};
223 
224 ConstantRange StackSafetyLocalAnalysis::offsetFrom(Value *Addr, Value *Base) {
225   if (!SE.isSCEVable(Addr->getType()))
226     return UnknownRange;
227 
228   AllocaOffsetRewriter Rewriter(SE, Base);
229   const SCEV *Expr = Rewriter.visit(SE.getSCEV(Addr));
230   ConstantRange Offset = SE.getSignedRange(Expr);
231   if (isUnsafe(Offset))
232     return UnknownRange;
233   return Offset.sextOrTrunc(PointerSize);
234 }
235 
236 ConstantRange
237 StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
238                                          ConstantRange SizeRange) {
239   // Zero-size loads and stores do not access memory.
240   if (SizeRange.isEmptySet())
241     return ConstantRange::getEmpty(PointerSize);
242   assert(!isUnsafe(SizeRange));
243 
244   ConstantRange Offsets = offsetFrom(Addr, Base);
245   if (isUnsafe(Offsets))
246     return UnknownRange;
247 
248   if (Offsets.signedAddMayOverflow(SizeRange) !=
249       ConstantRange::OverflowResult::NeverOverflows)
250     return UnknownRange;
251   Offsets = Offsets.add(SizeRange);
252   if (isUnsafe(Offsets))
253     return UnknownRange;
254   return Offsets;
255 }
256 
257 ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
258                                                        TypeSize Size) {
259   if (Size.isScalable())
260     return UnknownRange;
261   APInt APSize(PointerSize, Size.getFixedSize(), true);
262   if (APSize.isNegative())
263     return UnknownRange;
264   return getAccessRange(
265       Addr, Base, ConstantRange(APInt::getNullValue(PointerSize), APSize));
266 }
267 
268 ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
269     const MemIntrinsic *MI, const Use &U, Value *Base) {
270   if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
271     if (MTI->getRawSource() != U && MTI->getRawDest() != U)
272       return ConstantRange::getEmpty(PointerSize);
273   } else {
274     if (MI->getRawDest() != U)
275       return ConstantRange::getEmpty(PointerSize);
276   }
277 
278   auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
279   if (!SE.isSCEVable(MI->getLength()->getType()))
280     return UnknownRange;
281 
282   const SCEV *Expr =
283       SE.getTruncateOrZeroExtend(SE.getSCEV(MI->getLength()), CalculationTy);
284   ConstantRange Sizes = SE.getSignedRange(Expr);
285   if (Sizes.getUpper().isNegative() || isUnsafe(Sizes))
286     return UnknownRange;
287   Sizes = Sizes.sextOrTrunc(PointerSize);
288   ConstantRange SizeRange(APInt::getNullValue(PointerSize),
289                           Sizes.getUpper() - 1);
290   return getAccessRange(U, Base, SizeRange);
291 }
292 
/// The function analyzes all local uses of Ptr (alloca or argument) and
/// calculates local access range and all function calls where it was used.
///
/// Uses are followed transitively through every user not handled explicitly
/// below (e.g. bitcasts, GEPs, PHIs). offsetFrom() relates the derived
/// addresses back to Ptr. When a use makes the range unknown, the function
/// sets US.Range to include UnknownRange and returns false immediately,
/// without visiting the remaining uses. Otherwise it returns true.
bool StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
                                              UseInfo<GlobalValue> &US) {
  SmallPtrSet<const Value *, 16> Visited;
  SmallVector<const Value *, 8> WorkList;
  WorkList.push_back(Ptr);

  // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
  while (!WorkList.empty()) {
    const Value *V = WorkList.pop_back_val();
    for (const Use &UI : V->uses()) {
      const auto *I = cast<const Instruction>(UI.getUser());
      assert(V == UI.get());

      switch (I->getOpcode()) {
      case Instruction::Load: {
        // The load reads its type's store size starting at the used address.
        US.updateRange(
            getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
        break;
      }

      case Instruction::VAArg:
        // "va-arg" from a pointer is safe.
        break;
      case Instruction::Store: {
        if (V == I->getOperand(0)) {
          // Stored the pointer - conservatively assume it may be unsafe.
          US.updateRange(UnknownRange);
          return false;
        }
        // V is the address: the store writes the stored value's store size.
        US.updateRange(getAccessRange(
            UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
        break;
      }

      case Instruction::Ret:
        // Information leak.
        // FIXME: Process parameters correctly. This is a leak only if we return
        // alloca.
        US.updateRange(UnknownRange);
        return false;

      case Instruction::Call:
      case Instruction::Invoke: {
        const auto &CB = cast<CallBase>(*I);

        if (I->isLifetimeStartOrEnd())
          break;

        if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
          US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr));
          break;
        }

        // FIXME: consult devirt?
        // Do not follow aliases, otherwise we could inadvertently follow
        // dso_preemptable aliases or aliases with interposable linkage.
        const GlobalValue *Callee =
            dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
        if (!Callee) {
          US.updateRange(UnknownRange);
          return false;
        }

        assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee));

        // Record every argument slot that receives V; the callee's summary
        // is applied later by the data-flow. A call that uses V but never as
        // an argument is treated as unknown.
        // NOTE(review): the scan covers all arguments regardless of which
        // operand UI is. So when V is passed in several slots, each of its
        // uses records the same CallInfo entries again. Likewise, a
        // non-argument use of V (e.g. an operand bundle) is accepted as long
        // as V is also passed as an argument.
        int Found = 0;
        for (size_t ArgNo = 0; ArgNo < CB.getNumArgOperands(); ++ArgNo) {
          if (CB.getArgOperand(ArgNo) == V) {
            ++Found;
            US.Calls.emplace_back(Callee, ArgNo, offsetFrom(UI, Ptr));
          }
        }
        if (!Found) {
          US.updateRange(UnknownRange);
          return false;
        }

        break;
      }

      default:
        // Any other user yields a value derived from V whose uses are
        // followed as well (each instruction at most once).
        if (Visited.insert(I).second)
          WorkList.push_back(cast<const Instruction>(I));
      }
    }
  }

  return true;
}
384 
385 FunctionInfo<GlobalValue> StackSafetyLocalAnalysis::run() {
386   FunctionInfo<GlobalValue> Info;
387   assert(!F.isDeclaration() &&
388          "Can't run StackSafety on a function declaration");
389 
390   LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n");
391 
392   for (auto &I : instructions(F)) {
393     if (auto *AI = dyn_cast<AllocaInst>(&I)) {
394       auto &UI = Info.Allocas.emplace(AI, PointerSize).first->second;
395       analyzeAllUses(AI, UI);
396     }
397   }
398 
399   for (Argument &A : make_range(F.arg_begin(), F.arg_end())) {
400     if (A.getType()->isPointerTy()) {
401       auto &UI = Info.Params.emplace(A.getArgNo(), PointerSize).first->second;
402       analyzeAllUses(&A, UI);
403     }
404   }
405 
406   LLVM_DEBUG(Info.print(dbgs(), F.getName(), &F));
407   LLVM_DEBUG(dbgs() << "[StackSafety] done\n");
408   return Info;
409 }
410 
/// Inter-procedural propagation of parameter access ranges to a fixed point.
/// Each function's parameter ranges are widened with the ranges its callees
/// report for the arguments the parameter is passed to. Whenever a function
/// changes, its callers are queued again. Once a function has changed more
/// than StackSafetyMaxIterations times, any further widening of it goes
/// straight to the full set, which bounds the iteration.
template <typename CalleeTy> class StackSafetyDataFlowAnalysis {
  using FunctionMap = std::map<const CalleeTy *, FunctionInfo<CalleeTy>>;

  /// Summaries being refined; returned by run().
  FunctionMap Functions;
  /// Full set of the pointer bit width, used for anything not provable.
  const ConstantRange UnknownRange;

  // Callee-to-Caller multimap.
  DenseMap<const CalleeTy *, SmallVector<const CalleeTy *, 4>> Callers;
  /// Functions whose parameter ranges must be recomputed.
  SetVector<const CalleeTy *> WorkList;

  bool updateOneUse(UseInfo<CalleeTy> &US, bool UpdateToFullSet);
  void updateOneNode(const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS);
  void updateOneNode(const CalleeTy *Callee) {
    updateOneNode(Callee, Functions.find(Callee)->second);
  }
  void updateAllNodes() {
    for (auto &F : Functions)
      updateOneNode(F.first, F.second);
  }
  void runDataFlow();
#ifndef NDEBUG
  void verifyFixedPoint();
#endif

public:
  StackSafetyDataFlowAnalysis(uint32_t PointerBitWidth, FunctionMap Functions)
      : Functions(std::move(Functions)),
        UnknownRange(ConstantRange::getFull(PointerBitWidth)) {}

  /// Runs the propagation and returns the final summaries.
  const FunctionMap &run();

  /// Range accessed through parameter ParamNo of Callee when it receives an
  /// address at Offsets from the caller's base; UnknownRange if not provable.
  ConstantRange getArgumentAccessRange(const CalleeTy *Callee, unsigned ParamNo,
                                       const ConstantRange &Offsets) const;
};
445 
446 template <typename CalleeTy>
447 ConstantRange StackSafetyDataFlowAnalysis<CalleeTy>::getArgumentAccessRange(
448     const CalleeTy *Callee, unsigned ParamNo,
449     const ConstantRange &Offsets) const {
450   auto FnIt = Functions.find(Callee);
451   // Unknown callee (outside of LTO domain or an indirect call).
452   if (FnIt == Functions.end())
453     return UnknownRange;
454   auto &FS = FnIt->second;
455   auto ParamIt = FS.Params.find(ParamNo);
456   if (ParamIt == FS.Params.end())
457     return UnknownRange;
458   auto &Access = ParamIt->second.Range;
459   if (Access.isEmptySet())
460     return Access;
461   if (Access.isFullSet())
462     return UnknownRange;
463   if (Offsets.signedAddMayOverflow(Access) !=
464       ConstantRange::OverflowResult::NeverOverflows)
465     return UnknownRange;
466   return Access.add(Offsets);
467 }
468 
469 template <typename CalleeTy>
470 bool StackSafetyDataFlowAnalysis<CalleeTy>::updateOneUse(UseInfo<CalleeTy> &US,
471                                                          bool UpdateToFullSet) {
472   bool Changed = false;
473   for (auto &CS : US.Calls) {
474     assert(!CS.Offset.isEmptySet() &&
475            "Param range can't be empty-set, invalid offset range");
476 
477     ConstantRange CalleeRange =
478         getArgumentAccessRange(CS.Callee, CS.ParamNo, CS.Offset);
479     if (!US.Range.contains(CalleeRange)) {
480       Changed = true;
481       if (UpdateToFullSet)
482         US.Range = UnknownRange;
483       else
484         US.Range = US.Range.unionWith(CalleeRange);
485     }
486   }
487   return Changed;
488 }
489 
490 template <typename CalleeTy>
491 void StackSafetyDataFlowAnalysis<CalleeTy>::updateOneNode(
492     const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS) {
493   bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations;
494   bool Changed = false;
495   for (auto &KV : FS.Params)
496     Changed |= updateOneUse(KV.second, UpdateToFullSet);
497 
498   if (Changed) {
499     LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount
500                       << (UpdateToFullSet ? ", full-set" : "") << "] " << &FS
501                       << "\n");
502     // Callers of this function may need updating.
503     for (auto &CallerID : Callers[Callee])
504       WorkList.insert(CallerID);
505 
506     ++FS.UpdateCount;
507   }
508 }
509 
510 template <typename CalleeTy>
511 void StackSafetyDataFlowAnalysis<CalleeTy>::runDataFlow() {
512   SmallVector<const CalleeTy *, 16> Callees;
513   for (auto &F : Functions) {
514     Callees.clear();
515     auto &FS = F.second;
516     for (auto &KV : FS.Params)
517       for (auto &CS : KV.second.Calls)
518         Callees.push_back(CS.Callee);
519 
520     llvm::sort(Callees);
521     Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end());
522 
523     for (auto &Callee : Callees)
524       Callers[Callee].push_back(F.first);
525   }
526 
527   updateAllNodes();
528 
529   while (!WorkList.empty()) {
530     const CalleeTy *Callee = WorkList.back();
531     WorkList.pop_back();
532     updateOneNode(Callee);
533   }
534 }
535 
536 #ifndef NDEBUG
537 template <typename CalleeTy>
538 void StackSafetyDataFlowAnalysis<CalleeTy>::verifyFixedPoint() {
539   WorkList.clear();
540   updateAllNodes();
541   assert(WorkList.empty());
542 }
543 #endif
544 
545 template <typename CalleeTy>
546 const typename StackSafetyDataFlowAnalysis<CalleeTy>::FunctionMap &
547 StackSafetyDataFlowAnalysis<CalleeTy>::run() {
548   runDataFlow();
549   LLVM_DEBUG(verifyFixedPoint());
550   return Functions;
551 }
552 
553 const Function *findCalleeInModule(const GlobalValue *GV) {
554   while (GV) {
555     if (GV->isInterposable() || !GV->isDSOLocal())
556       return nullptr;
557     if (const Function *F = dyn_cast<Function>(GV))
558       return F;
559     const GlobalAlias *A = dyn_cast<GlobalAlias>(GV);
560     if (!A)
561       return nullptr;
562     GV = A->getBaseObject();
563     if (GV == A)
564       return nullptr;
565   }
566   return nullptr;
567 }
568 
569 template <typename CalleeTy> void resolveAllCalls(UseInfo<CalleeTy> &Use) {
570   ConstantRange FullSet(Use.Range.getBitWidth(), true);
571   for (auto &C : Use.Calls) {
572     const Function *F = findCalleeInModule(C.Callee);
573     if (F) {
574       C.Callee = F;
575       continue;
576     }
577 
578     return Use.updateRange(FullSet);
579   }
580 }
581 
582 GVToSSI createGlobalStackSafetyInfo(
583     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions) {
584   GVToSSI SSI;
585   if (Functions.empty())
586     return SSI;
587 
588   // FIXME: Simplify printing and remove copying here.
589   auto Copy = Functions;
590 
591   for (auto &FnKV : Copy)
592     for (auto &KV : FnKV.second.Params)
593       resolveAllCalls(KV.second);
594 
595   uint32_t PointerSize = Copy.begin()
596                              ->first->getParent()
597                              ->getDataLayout()
598                              .getMaxPointerSizeInBits();
599   StackSafetyDataFlowAnalysis<GlobalValue> SSDFA(PointerSize, std::move(Copy));
600 
601   for (auto &F : SSDFA.run()) {
602     auto FI = F.second;
603     auto &SrcF = Functions[F.first];
604     for (auto &KV : FI.Allocas) {
605       auto &A = KV.second;
606       resolveAllCalls(A);
607       for (auto &C : A.Calls) {
608         A.updateRange(
609             SSDFA.getArgumentAccessRange(C.Callee, C.ParamNo, C.Offset));
610       }
611       // FIXME: This is needed only to preserve calls in print() results.
612       A.Calls = SrcF.Allocas.find(KV.first)->second.Calls;
613     }
614     for (auto &KV : FI.Params) {
615       auto &P = KV.second;
616       P.Calls = SrcF.Params.find(KV.first)->second.Calls;
617     }
618     SSI[F.first] = std::move(FI);
619   }
620 
621   return SSI;
622 }
623 
624 } // end anonymous namespace
625 
626 StackSafetyInfo::StackSafetyInfo() = default;
627 
628 StackSafetyInfo::StackSafetyInfo(Function *F,
629                                  std::function<ScalarEvolution &()> GetSE)
630     : F(F), GetSE(GetSE) {}
631 
632 StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default;
633 
634 StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default;
635 
636 StackSafetyInfo::~StackSafetyInfo() = default;
637 
638 const StackSafetyInfo::InfoTy &StackSafetyInfo::getInfo() const {
639   if (!Info) {
640     StackSafetyLocalAnalysis SSLA(*F, GetSE());
641     Info.reset(new InfoTy{SSLA.run()});
642   }
643   return *Info;
644 }
645 
646 void StackSafetyInfo::print(raw_ostream &O) const {
647   getInfo().Info.print(O, F->getName(), dyn_cast<Function>(F));
648 }
649 
650 const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
651   if (!Info) {
652     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions;
653     for (auto &F : M->functions()) {
654       if (!F.isDeclaration()) {
655         auto FI = GetSSI(F).getInfo().Info;
656         Functions.emplace(&F, std::move(FI));
657       }
658     }
659     Info.reset(
660         new InfoTy{createGlobalStackSafetyInfo(std::move(Functions)), {}});
661     for (auto &FnKV : Info->Info) {
662       for (auto &KV : FnKV.second.Allocas) {
663         const AllocaInst *AI = KV.first;
664         if (getStaticAllocaSizeRange(*AI).contains(KV.second.Range))
665           Info->SafeAllocas.insert(AI);
666       }
667     }
668     if (StackSafetyPrint)
669       print(errs());
670   }
671   return *Info;
672 }
673 
674 StackSafetyGlobalInfo::StackSafetyGlobalInfo() = default;
675 
676 StackSafetyGlobalInfo::StackSafetyGlobalInfo(
677     Module *M, std::function<const StackSafetyInfo &(Function &F)> GetSSI)
678     : M(M), GetSSI(GetSSI) {
679   if (StackSafetyPrint > 1)
680     getInfo();
681 }
682 
683 StackSafetyGlobalInfo::StackSafetyGlobalInfo(StackSafetyGlobalInfo &&) =
684     default;
685 
686 StackSafetyGlobalInfo &
687 StackSafetyGlobalInfo::operator=(StackSafetyGlobalInfo &&) = default;
688 
689 StackSafetyGlobalInfo::~StackSafetyGlobalInfo() = default;
690 
691 bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
692   const auto &Info = getInfo();
693   return Info.SafeAllocas.find(&AI) != Info.SafeAllocas.end();
694 }
695 
696 void StackSafetyGlobalInfo::print(raw_ostream &O) const {
697   auto &SSI = getInfo().Info;
698   if (SSI.empty())
699     return;
700   const Module &M = *SSI.begin()->first->getParent();
701   for (auto &F : M.functions()) {
702     if (!F.isDeclaration()) {
703       SSI.find(&F)->second.print(O, F.getName(), &F);
704       O << "\n";
705     }
706   }
707 }
708 
709 LLVM_DUMP_METHOD void StackSafetyGlobalInfo::dump() const { print(dbgs()); }
710 
711 AnalysisKey StackSafetyAnalysis::Key;
712 
713 StackSafetyInfo StackSafetyAnalysis::run(Function &F,
714                                          FunctionAnalysisManager &AM) {
715   return StackSafetyInfo(&F, [&AM, &F]() -> ScalarEvolution & {
716     return AM.getResult<ScalarEvolutionAnalysis>(F);
717   });
718 }
719 
720 PreservedAnalyses StackSafetyPrinterPass::run(Function &F,
721                                               FunctionAnalysisManager &AM) {
722   OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n";
723   AM.getResult<StackSafetyAnalysis>(F).print(OS);
724   return PreservedAnalyses::all();
725 }
726 
727 char StackSafetyInfoWrapperPass::ID = 0;
728 
729 StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) {
730   initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry());
731 }
732 
733 void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
734   AU.addRequiredTransitive<ScalarEvolutionWrapperPass>();
735   AU.setPreservesAll();
736 }
737 
738 void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const {
739   SSI.print(O);
740 }
741 
742 bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) {
743   auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
744   SSI = {&F, [SE]() -> ScalarEvolution & { return *SE; }};
745   return false;
746 }
747 
748 AnalysisKey StackSafetyGlobalAnalysis::Key;
749 
750 StackSafetyGlobalInfo
751 StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
752   FunctionAnalysisManager &FAM =
753       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
754   return {&M, [&FAM](Function &F) -> const StackSafetyInfo & {
755             return FAM.getResult<StackSafetyAnalysis>(F);
756           }};
757 }
758 
759 PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M,
760                                                     ModuleAnalysisManager &AM) {
761   OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n";
762   AM.getResult<StackSafetyGlobalAnalysis>(M).print(OS);
763   return PreservedAnalyses::all();
764 }
765 
766 char StackSafetyGlobalInfoWrapperPass::ID = 0;
767 
768 StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass()
769     : ModulePass(ID) {
770   initializeStackSafetyGlobalInfoWrapperPassPass(
771       *PassRegistry::getPassRegistry());
772 }
773 
774 StackSafetyGlobalInfoWrapperPass::~StackSafetyGlobalInfoWrapperPass() = default;
775 
776 void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O,
777                                              const Module *M) const {
778   SSGI.print(O);
779 }
780 
781 void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage(
782     AnalysisUsage &AU) const {
783   AU.setPreservesAll();
784   AU.addRequired<StackSafetyInfoWrapperPass>();
785 }
786 
787 bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) {
788   SSGI = {&M, [this](Function &F) -> const StackSafetyInfo & {
789             return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult();
790           }};
791   return false;
792 }
793 
// Legacy pass manager registration. The last two INITIALIZE_PASS_* arguments
// are "CFG only" (false) and "is analysis" (true).
static const char LocalPassArg[] = "stack-safety-local";
static const char LocalPassName[] = "Stack Safety Local Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                      false, true)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                    false, true)

// The module-wide pass is registered under DEBUG_TYPE ("stack-safety").
static const char GlobalPassName[] = "Stack Safety Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                      GlobalPassName, false, true)
INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass)
INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                    GlobalPassName, false, true)
808