1 //===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "llvm/Analysis/StackSafetyAnalysis.h"
12 #include "llvm/ADT/APInt.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/Statistic.h"
15 #include "llvm/Analysis/ModuleSummaryAnalysis.h"
16 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
17 #include "llvm/IR/ConstantRange.h"
18 #include "llvm/IR/DerivedTypes.h"
19 #include "llvm/IR/GlobalValue.h"
20 #include "llvm/IR/InstIterator.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/Support/Casting.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <algorithm>
29 #include <memory>
30 
31 using namespace llvm;
32 
33 #define DEBUG_TYPE "stack-safety"
34 
// Number of allocas whose accumulated access range is contained in their
// static allocation size range.
STATISTIC(NumAllocaStackSafe, "Number of safe allocas");
// Number of allocas examined by the module-level analysis.
STATISTIC(NumAllocaTotal, "Number of total allocas");

// Once a function's parameter summary has been updated more than this many
// times during the data-flow iteration, further growth widens the affected
// ranges straight to full-set so the iteration terminates.
static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations",
                                             cl::init(20), cl::Hidden);

// Print the module-level results to stderr once they have been computed.
static cl::opt<bool> StackSafetyPrint("stack-safety-print", cl::init(false),
                                      cl::Hidden);

// Compute the module-level results eagerly in the StackSafetyGlobalInfo
// constructor; also makes needsParamAccessSummary() return true.
static cl::opt<bool> StackSafetyRun("stack-safety-run", cl::init(false),
                                    cl::Hidden);
46 
47 namespace {
48 
49 /// Describes use of address in as a function call argument.
50 template <typename CalleeTy> struct CallInfo {
51   /// Function being called.
52   const CalleeTy *Callee = nullptr;
53   /// Index of argument which pass address.
54   size_t ParamNo = 0;
55   // Offset range of address from base address (alloca or calling function
56   // argument).
57   // Range should never set to empty-set, that is an invalid access range
58   // that can cause empty-set to be propagated with ConstantRange::add
59   ConstantRange Offset;
60   CallInfo(const CalleeTy *Callee, size_t ParamNo, ConstantRange Offset)
61       : Callee(Callee), ParamNo(ParamNo), Offset(Offset) {}
62 };
63 
64 template <typename CalleeTy>
65 raw_ostream &operator<<(raw_ostream &OS, const CallInfo<CalleeTy> &P) {
66   return OS << "@" << P.Callee->getName() << "(arg" << P.ParamNo << ", "
67             << P.Offset << ")";
68 }
69 
70 /// Describe uses of address (alloca or parameter) inside of the function.
71 template <typename CalleeTy> struct UseInfo {
72   // Access range if the address (alloca or parameters).
73   // It is allowed to be empty-set when there are no known accesses.
74   ConstantRange Range;
75 
76   // List of calls which pass address as an argument.
77   SmallVector<CallInfo<CalleeTy>, 4> Calls;
78 
79   UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
80 
81   void updateRange(const ConstantRange &R) {
82     assert(!R.isUpperSignWrapped());
83     Range = Range.unionWith(R);
84     assert(!Range.isUpperSignWrapped());
85   }
86 };
87 
88 template <typename CalleeTy>
89 raw_ostream &operator<<(raw_ostream &OS, const UseInfo<CalleeTy> &U) {
90   OS << U.Range;
91   for (auto &Call : U.Calls)
92     OS << ", " << Call;
93   return OS;
94 }
95 
96 // Check if we should bailout for such ranges.
97 bool isUnsafe(const ConstantRange &R) {
98   return R.isEmptySet() || R.isFullSet() || R.isUpperSignWrapped();
99 }
100 
101 /// Calculate the allocation size of a given alloca. Returns empty range
102 // in case of confution.
103 ConstantRange getStaticAllocaSizeRange(const AllocaInst &AI) {
104   const DataLayout &DL = AI.getModule()->getDataLayout();
105   TypeSize TS = DL.getTypeAllocSize(AI.getAllocatedType());
106   unsigned PointerSize = DL.getMaxPointerSizeInBits();
107   // Fallback to empty range for alloca size.
108   ConstantRange R = ConstantRange::getEmpty(PointerSize);
109   if (TS.isScalable())
110     return R;
111   APInt APSize(PointerSize, TS.getFixedSize(), true);
112   if (APSize.isNonPositive())
113     return R;
114   if (AI.isArrayAllocation()) {
115     const auto *C = dyn_cast<ConstantInt>(AI.getArraySize());
116     if (!C)
117       return R;
118     bool Overflow = false;
119     APInt Mul = C->getValue();
120     if (Mul.isNonPositive())
121       return R;
122     Mul = Mul.sextOrTrunc(PointerSize);
123     APSize = APSize.smul_ov(Mul, Overflow);
124     if (Overflow)
125       return R;
126   }
127   R = ConstantRange(APInt::getNullValue(PointerSize), APSize);
128   assert(!isUnsafe(R));
129   return R;
130 }
131 
132 template <typename CalleeTy> struct FunctionInfo {
133   std::map<const AllocaInst *, UseInfo<CalleeTy>> Allocas;
134   std::map<uint32_t, UseInfo<CalleeTy>> Params;
135   // TODO: describe return value as depending on one or more of its arguments.
136 
137   // StackSafetyDataFlowAnalysis counter stored here for faster access.
138   int UpdateCount = 0;
139 
140   void print(raw_ostream &O, StringRef Name, const Function *F) const {
141     // TODO: Consider different printout format after
142     // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then.
143     O << "  @" << Name << ((F && F->isDSOLocal()) ? "" : " dso_preemptable")
144       << ((F && F->isInterposable()) ? " interposable" : "") << "\n";
145 
146     O << "    args uses:\n";
147     for (auto &KV : Params) {
148       O << "      ";
149       if (F)
150         O << F->getArg(KV.first)->getName();
151       else
152         O << formatv("arg{0}", KV.first);
153       O << "[]: " << KV.second << "\n";
154     }
155 
156     O << "    allocas uses:\n";
157     if (F) {
158       for (auto &I : instructions(F)) {
159         if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
160           auto &AS = Allocas.find(AI)->second;
161           O << "      " << AI->getName() << "["
162             << getStaticAllocaSizeRange(*AI).getUpper() << "]: " << AS << "\n";
163         }
164       }
165     } else {
166       assert(Allocas.empty());
167     }
168   }
169 };
170 
// Maps each analyzed function to its stack safety summary.
using GVToSSI = std::map<const GlobalValue *, FunctionInfo<GlobalValue>>;
172 
173 } // namespace
174 
// Holds the local (per-function) summary, computed lazily by
// StackSafetyInfo::getInfo().
struct StackSafetyInfo::InfoTy {
  FunctionInfo<GlobalValue> Info;
};
178 
// Holds the module-level results, computed lazily by
// StackSafetyGlobalInfo::getInfo().
struct StackSafetyGlobalInfo::InfoTy {
  // Per-function summaries after inter-procedural propagation.
  GVToSSI Info;
  // Allocas whose access range is contained in their static size range.
  SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
};
183 
184 namespace {
185 
/// Computes the local (intra-procedural) stack safety summary of a function:
/// for every alloca and pointer argument, the accessed offset range and the
/// calls the address is passed to.
class StackSafetyLocalAnalysis {
  Function &F;
  const DataLayout &DL;
  ScalarEvolution &SE;
  // Initialized from DL, so it must stay declared after DL.
  unsigned PointerSize = 0;

  // Full-set range used whenever an access cannot be bounded.
  const ConstantRange UnknownRange;

  // Signed range of (Addr - Base), or UnknownRange if it cannot be bounded.
  ConstantRange offsetFrom(Value *Addr, Value *Base);
  // Byte range, relative to Base, touched by an access of SizeRange bytes at
  // Addr.
  ConstantRange getAccessRange(Value *Addr, Value *Base,
                               ConstantRange SizeRange);
  ConstantRange getAccessRange(Value *Addr, Value *Base, TypeSize Size);
  ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U,
                                           Value *Base);

  // Accumulates all uses of Ptr into AS; returns false once a use is found
  // that makes the accesses unknown.
  bool analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS);

public:
  StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
      : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
        PointerSize(DL.getPointerSizeInBits()),
        UnknownRange(PointerSize, true) {}

  // Analyze the associated function and return its local summary.
  FunctionInfo<GlobalValue> run();
};
212 
213 ConstantRange StackSafetyLocalAnalysis::offsetFrom(Value *Addr, Value *Base) {
214   if (!SE.isSCEVable(Addr->getType()) || !SE.isSCEVable(Base->getType()))
215     return UnknownRange;
216 
217   auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
218   const SCEV *AddrExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Addr), PtrTy);
219   const SCEV *BaseExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Base), PtrTy);
220   const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
221 
222   ConstantRange Offset = SE.getSignedRange(Diff);
223   if (isUnsafe(Offset))
224     return UnknownRange;
225   return Offset.sextOrTrunc(PointerSize);
226 }
227 
228 ConstantRange
229 StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
230                                          ConstantRange SizeRange) {
231   // Zero-size loads and stores do not access memory.
232   if (SizeRange.isEmptySet())
233     return ConstantRange::getEmpty(PointerSize);
234   assert(!isUnsafe(SizeRange));
235 
236   ConstantRange Offsets = offsetFrom(Addr, Base);
237   if (isUnsafe(Offsets))
238     return UnknownRange;
239 
240   if (Offsets.signedAddMayOverflow(SizeRange) !=
241       ConstantRange::OverflowResult::NeverOverflows)
242     return UnknownRange;
243   Offsets = Offsets.add(SizeRange);
244   if (isUnsafe(Offsets))
245     return UnknownRange;
246   return Offsets;
247 }
248 
249 ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
250                                                        TypeSize Size) {
251   if (Size.isScalable())
252     return UnknownRange;
253   APInt APSize(PointerSize, Size.getFixedSize(), true);
254   if (APSize.isNegative())
255     return UnknownRange;
256   return getAccessRange(
257       Addr, Base, ConstantRange(APInt::getNullValue(PointerSize), APSize));
258 }
259 
260 ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
261     const MemIntrinsic *MI, const Use &U, Value *Base) {
262   if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
263     if (MTI->getRawSource() != U && MTI->getRawDest() != U)
264       return ConstantRange::getEmpty(PointerSize);
265   } else {
266     if (MI->getRawDest() != U)
267       return ConstantRange::getEmpty(PointerSize);
268   }
269 
270   auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
271   if (!SE.isSCEVable(MI->getLength()->getType()))
272     return UnknownRange;
273 
274   const SCEV *Expr =
275       SE.getTruncateOrZeroExtend(SE.getSCEV(MI->getLength()), CalculationTy);
276   ConstantRange Sizes = SE.getSignedRange(Expr);
277   if (Sizes.getUpper().isNegative() || isUnsafe(Sizes))
278     return UnknownRange;
279   Sizes = Sizes.sextOrTrunc(PointerSize);
280   ConstantRange SizeRange(APInt::getNullValue(PointerSize),
281                           Sizes.getUpper() - 1);
282   return getAccessRange(U, Base, SizeRange);
283 }
284 
285 /// The function analyzes all local uses of Ptr (alloca or argument) and
286 /// calculates local access range and all function calls where it was used.
287 bool StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
288                                               UseInfo<GlobalValue> &US) {
289   SmallPtrSet<const Value *, 16> Visited;
290   SmallVector<const Value *, 8> WorkList;
291   WorkList.push_back(Ptr);
292 
293   // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
294   while (!WorkList.empty()) {
295     const Value *V = WorkList.pop_back_val();
296     for (const Use &UI : V->uses()) {
297       const auto *I = cast<const Instruction>(UI.getUser());
298       assert(V == UI.get());
299 
300       switch (I->getOpcode()) {
301       case Instruction::Load: {
302         US.updateRange(
303             getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
304         break;
305       }
306 
307       case Instruction::VAArg:
308         // "va-arg" from a pointer is safe.
309         break;
310       case Instruction::Store: {
311         if (V == I->getOperand(0)) {
312           // Stored the pointer - conservatively assume it may be unsafe.
313           US.updateRange(UnknownRange);
314           return false;
315         }
316         US.updateRange(getAccessRange(
317             UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
318         break;
319       }
320 
321       case Instruction::Ret:
322         // Information leak.
323         // FIXME: Process parameters correctly. This is a leak only if we return
324         // alloca.
325         US.updateRange(UnknownRange);
326         return false;
327 
328       case Instruction::Call:
329       case Instruction::Invoke: {
330         const auto &CB = cast<CallBase>(*I);
331 
332         if (I->isLifetimeStartOrEnd())
333           break;
334 
335         if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
336           US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr));
337           break;
338         }
339 
340         // FIXME: consult devirt?
341         // Do not follow aliases, otherwise we could inadvertently follow
342         // dso_preemptable aliases or aliases with interposable linkage.
343         const GlobalValue *Callee =
344             dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
345         if (!Callee) {
346           US.updateRange(UnknownRange);
347           return false;
348         }
349 
350         assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee));
351 
352         int Found = 0;
353         for (size_t ArgNo = 0; ArgNo < CB.getNumArgOperands(); ++ArgNo) {
354           if (CB.getArgOperand(ArgNo) == V) {
355             ++Found;
356             US.Calls.emplace_back(Callee, ArgNo, offsetFrom(UI, Ptr));
357           }
358         }
359         if (!Found) {
360           US.updateRange(UnknownRange);
361           return false;
362         }
363 
364         break;
365       }
366 
367       default:
368         if (Visited.insert(I).second)
369           WorkList.push_back(cast<const Instruction>(I));
370       }
371     }
372   }
373 
374   return true;
375 }
376 
377 FunctionInfo<GlobalValue> StackSafetyLocalAnalysis::run() {
378   FunctionInfo<GlobalValue> Info;
379   assert(!F.isDeclaration() &&
380          "Can't run StackSafety on a function declaration");
381 
382   LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n");
383 
384   for (auto &I : instructions(F)) {
385     if (auto *AI = dyn_cast<AllocaInst>(&I)) {
386       auto &UI = Info.Allocas.emplace(AI, PointerSize).first->second;
387       analyzeAllUses(AI, UI);
388     }
389   }
390 
391   for (Argument &A : make_range(F.arg_begin(), F.arg_end())) {
392     if (A.getType()->isPointerTy()) {
393       auto &UI = Info.Params.emplace(A.getArgNo(), PointerSize).first->second;
394       analyzeAllUses(&A, UI);
395     }
396   }
397 
398   LLVM_DEBUG(Info.print(dbgs(), F.getName(), &F));
399   LLVM_DEBUG(dbgs() << "[StackSafety] done\n");
400   return Info;
401 }
402 
/// Propagates parameter access ranges between the functions in Functions
/// along the recorded calls, iterating until a fixed point is reached.
template <typename CalleeTy> class StackSafetyDataFlowAnalysis {
  using FunctionMap = std::map<const CalleeTy *, FunctionInfo<CalleeTy>>;

  // Summaries being refined; returned by run().
  FunctionMap Functions;
  // Full-set range used for anything that cannot be bounded.
  const ConstantRange UnknownRange;

  // Callee-to-Caller multimap.
  DenseMap<const CalleeTy *, SmallVector<const CalleeTy *, 4>> Callers;
  // Functions whose parameter summaries must be recomputed.
  SetVector<const CalleeTy *> WorkList;

  bool updateOneUse(UseInfo<CalleeTy> &US, bool UpdateToFullSet);
  void updateOneNode(const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS);
  void updateOneNode(const CalleeTy *Callee) {
    updateOneNode(Callee, Functions.find(Callee)->second);
  }
  void updateAllNodes() {
    for (auto &F : Functions)
      updateOneNode(F.first, F.second);
  }
  void runDataFlow();
#ifndef NDEBUG
  void verifyFixedPoint();
#endif

public:
  StackSafetyDataFlowAnalysis(uint32_t PointerBitWidth, FunctionMap Functions)
      : Functions(std::move(Functions)),
        UnknownRange(ConstantRange::getFull(PointerBitWidth)) {}

  // Runs the propagation to a fixed point and returns the summaries.
  const FunctionMap &run();

  // Range accessed through parameter ParamNo of Callee when it receives an
  // address with offset range Offsets; UnknownRange if unknown.
  ConstantRange getArgumentAccessRange(const CalleeTy *Callee, unsigned ParamNo,
                                       const ConstantRange &Offsets) const;
};
437 
438 template <typename CalleeTy>
439 ConstantRange StackSafetyDataFlowAnalysis<CalleeTy>::getArgumentAccessRange(
440     const CalleeTy *Callee, unsigned ParamNo,
441     const ConstantRange &Offsets) const {
442   auto FnIt = Functions.find(Callee);
443   // Unknown callee (outside of LTO domain or an indirect call).
444   if (FnIt == Functions.end())
445     return UnknownRange;
446   auto &FS = FnIt->second;
447   auto ParamIt = FS.Params.find(ParamNo);
448   if (ParamIt == FS.Params.end())
449     return UnknownRange;
450   auto &Access = ParamIt->second.Range;
451   if (Access.isEmptySet())
452     return Access;
453   if (Access.isFullSet())
454     return UnknownRange;
455   if (Offsets.signedAddMayOverflow(Access) !=
456       ConstantRange::OverflowResult::NeverOverflows)
457     return UnknownRange;
458   return Access.add(Offsets);
459 }
460 
461 template <typename CalleeTy>
462 bool StackSafetyDataFlowAnalysis<CalleeTy>::updateOneUse(UseInfo<CalleeTy> &US,
463                                                          bool UpdateToFullSet) {
464   bool Changed = false;
465   for (auto &CS : US.Calls) {
466     assert(!CS.Offset.isEmptySet() &&
467            "Param range can't be empty-set, invalid offset range");
468 
469     ConstantRange CalleeRange =
470         getArgumentAccessRange(CS.Callee, CS.ParamNo, CS.Offset);
471     if (!US.Range.contains(CalleeRange)) {
472       Changed = true;
473       if (UpdateToFullSet)
474         US.Range = UnknownRange;
475       else
476         US.Range = US.Range.unionWith(CalleeRange);
477     }
478   }
479   return Changed;
480 }
481 
482 template <typename CalleeTy>
483 void StackSafetyDataFlowAnalysis<CalleeTy>::updateOneNode(
484     const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS) {
485   bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations;
486   bool Changed = false;
487   for (auto &KV : FS.Params)
488     Changed |= updateOneUse(KV.second, UpdateToFullSet);
489 
490   if (Changed) {
491     LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount
492                       << (UpdateToFullSet ? ", full-set" : "") << "] " << &FS
493                       << "\n");
494     // Callers of this function may need updating.
495     for (auto &CallerID : Callers[Callee])
496       WorkList.insert(CallerID);
497 
498     ++FS.UpdateCount;
499   }
500 }
501 
502 template <typename CalleeTy>
503 void StackSafetyDataFlowAnalysis<CalleeTy>::runDataFlow() {
504   SmallVector<const CalleeTy *, 16> Callees;
505   for (auto &F : Functions) {
506     Callees.clear();
507     auto &FS = F.second;
508     for (auto &KV : FS.Params)
509       for (auto &CS : KV.second.Calls)
510         Callees.push_back(CS.Callee);
511 
512     llvm::sort(Callees);
513     Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end());
514 
515     for (auto &Callee : Callees)
516       Callers[Callee].push_back(F.first);
517   }
518 
519   updateAllNodes();
520 
521   while (!WorkList.empty()) {
522     const CalleeTy *Callee = WorkList.back();
523     WorkList.pop_back();
524     updateOneNode(Callee);
525   }
526 }
527 
#ifndef NDEBUG
// Debug-only check that the data flow reached a fixed point: re-evaluating
// every function must not queue any further work.
template <typename CalleeTy>
void StackSafetyDataFlowAnalysis<CalleeTy>::verifyFixedPoint() {
  WorkList.clear();
  updateAllNodes();
  assert(WorkList.empty());
}
#endif
536 
537 template <typename CalleeTy>
538 const typename StackSafetyDataFlowAnalysis<CalleeTy>::FunctionMap &
539 StackSafetyDataFlowAnalysis<CalleeTy>::run() {
540   runDataFlow();
541   LLVM_DEBUG(verifyFixedPoint());
542   return Functions;
543 }
544 
545 const Function *findCalleeInModule(const GlobalValue *GV) {
546   while (GV) {
547     if (GV->isDeclaration() || GV->isInterposable() || !GV->isDSOLocal())
548       return nullptr;
549     if (const Function *F = dyn_cast<Function>(GV))
550       return F;
551     const GlobalAlias *A = dyn_cast<GlobalAlias>(GV);
552     if (!A)
553       return nullptr;
554     GV = A->getBaseObject();
555     if (GV == A)
556       return nullptr;
557   }
558   return nullptr;
559 }
560 
561 template <typename CalleeTy> void resolveAllCalls(UseInfo<CalleeTy> &Use) {
562   ConstantRange FullSet(Use.Range.getBitWidth(), true);
563   for (auto &C : Use.Calls) {
564     const Function *F = findCalleeInModule(C.Callee);
565     if (F) {
566       C.Callee = F;
567       continue;
568     }
569 
570     return Use.updateRange(FullSet);
571   }
572 }
573 
574 GVToSSI createGlobalStackSafetyInfo(
575     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions,
576     const ModuleSummaryIndex *Index) {
577   GVToSSI SSI;
578   if (Functions.empty())
579     return SSI;
580 
581   // FIXME: Simplify printing and remove copying here.
582   auto Copy = Functions;
583 
584   for (auto &FnKV : Copy)
585     for (auto &KV : FnKV.second.Params)
586       resolveAllCalls(KV.second);
587 
588   uint32_t PointerSize = Copy.begin()
589                              ->first->getParent()
590                              ->getDataLayout()
591                              .getMaxPointerSizeInBits();
592   StackSafetyDataFlowAnalysis<GlobalValue> SSDFA(PointerSize, std::move(Copy));
593 
594   for (auto &F : SSDFA.run()) {
595     auto FI = F.second;
596     auto &SrcF = Functions[F.first];
597     for (auto &KV : FI.Allocas) {
598       auto &A = KV.second;
599       resolveAllCalls(A);
600       for (auto &C : A.Calls) {
601         A.updateRange(
602             SSDFA.getArgumentAccessRange(C.Callee, C.ParamNo, C.Offset));
603       }
604       // FIXME: This is needed only to preserve calls in print() results.
605       A.Calls = SrcF.Allocas.find(KV.first)->second.Calls;
606     }
607     for (auto &KV : FI.Params) {
608       auto &P = KV.second;
609       P.Calls = SrcF.Params.find(KV.first)->second.Calls;
610     }
611     SSI[F.first] = std::move(FI);
612   }
613 
614   return SSI;
615 }
616 
617 } // end anonymous namespace
618 
StackSafetyInfo::StackSafetyInfo() = default;

// Only stores F and the ScalarEvolution getter; the local analysis itself is
// run lazily by getInfo().
StackSafetyInfo::StackSafetyInfo(Function *F,
                                 std::function<ScalarEvolution &()> GetSE)
    : F(F), GetSE(GetSE) {}

StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default;

StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default;

StackSafetyInfo::~StackSafetyInfo() = default;
630 
631 const StackSafetyInfo::InfoTy &StackSafetyInfo::getInfo() const {
632   if (!Info) {
633     StackSafetyLocalAnalysis SSLA(*F, GetSE());
634     Info.reset(new InfoTy{SSLA.run()});
635   }
636   return *Info;
637 }
638 
639 void StackSafetyInfo::print(raw_ostream &O) const {
640   getInfo().Info.print(O, F->getName(), dyn_cast<Function>(F));
641 }
642 
643 const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
644   if (!Info) {
645     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions;
646     for (auto &F : M->functions()) {
647       if (!F.isDeclaration()) {
648         auto FI = GetSSI(F).getInfo().Info;
649         Functions.emplace(&F, std::move(FI));
650       }
651     }
652     Info.reset(new InfoTy{
653         createGlobalStackSafetyInfo(std::move(Functions), Index), {}});
654     for (auto &FnKV : Info->Info) {
655       for (auto &KV : FnKV.second.Allocas) {
656         ++NumAllocaTotal;
657         const AllocaInst *AI = KV.first;
658         if (getStaticAllocaSizeRange(*AI).contains(KV.second.Range)) {
659           Info->SafeAllocas.insert(AI);
660           ++NumAllocaStackSafe;
661         }
662       }
663     }
664     if (StackSafetyPrint)
665       print(errs());
666   }
667   return *Info;
668 }
669 
670 // Converts a StackSafetyFunctionInfo to the relevant FunctionSummary
671 // constructor fields
672 std::vector<FunctionSummary::ParamAccess>
673 StackSafetyInfo::getParamAccesses() const {
674   assert(needsParamAccessSummary(*F->getParent()));
675 
676   std::vector<FunctionSummary::ParamAccess> ParamAccesses;
677   for (const auto &KV : getInfo().Info.Params) {
678     auto &PS = KV.second;
679     if (PS.Range.isFullSet())
680       continue;
681 
682     ParamAccesses.emplace_back(KV.first, PS.Range);
683     FunctionSummary::ParamAccess &Param = ParamAccesses.back();
684 
685     Param.Calls.reserve(PS.Calls.size());
686     for (auto &C : PS.Calls) {
687       if (C.Offset.isFullSet()) {
688         ParamAccesses.pop_back();
689         break;
690       }
691       Param.Calls.emplace_back(C.ParamNo, C.Callee->getGUID(), C.Offset);
692     }
693   }
694   return ParamAccesses;
695 }
696 
697 StackSafetyGlobalInfo::StackSafetyGlobalInfo() = default;
698 
699 StackSafetyGlobalInfo::StackSafetyGlobalInfo(
700     Module *M, std::function<const StackSafetyInfo &(Function &F)> GetSSI,
701     const ModuleSummaryIndex *Index)
702     : M(M), GetSSI(GetSSI), Index(Index) {
703   if (StackSafetyRun)
704     getInfo();
705 }
706 
707 StackSafetyGlobalInfo::StackSafetyGlobalInfo(StackSafetyGlobalInfo &&) =
708     default;
709 
710 StackSafetyGlobalInfo &
711 StackSafetyGlobalInfo::operator=(StackSafetyGlobalInfo &&) = default;
712 
713 StackSafetyGlobalInfo::~StackSafetyGlobalInfo() = default;
714 
715 bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
716   const auto &Info = getInfo();
717   return Info.SafeAllocas.count(&AI);
718 }
719 
720 void StackSafetyGlobalInfo::print(raw_ostream &O) const {
721   auto &SSI = getInfo().Info;
722   if (SSI.empty())
723     return;
724   const Module &M = *SSI.begin()->first->getParent();
725   for (auto &F : M.functions()) {
726     if (!F.isDeclaration()) {
727       SSI.find(&F)->second.print(O, F.getName(), &F);
728       O << "\n";
729     }
730   }
731 }
732 
// Prints the module-level results to dbgs().
LLVM_DUMP_METHOD void StackSafetyGlobalInfo::dump() const { print(dbgs()); }
734 
735 AnalysisKey StackSafetyAnalysis::Key;
736 
737 StackSafetyInfo StackSafetyAnalysis::run(Function &F,
738                                          FunctionAnalysisManager &AM) {
739   return StackSafetyInfo(&F, [&AM, &F]() -> ScalarEvolution & {
740     return AM.getResult<ScalarEvolutionAnalysis>(F);
741   });
742 }
743 
744 PreservedAnalyses StackSafetyPrinterPass::run(Function &F,
745                                               FunctionAnalysisManager &AM) {
746   OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n";
747   AM.getResult<StackSafetyAnalysis>(F).print(OS);
748   return PreservedAnalyses::all();
749 }
750 
751 char StackSafetyInfoWrapperPass::ID = 0;
752 
753 StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) {
754   initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry());
755 }
756 
757 void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
758   AU.addRequiredTransitive<ScalarEvolutionWrapperPass>();
759   AU.setPreservesAll();
760 }
761 
762 void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const {
763   SSI.print(O);
764 }
765 
766 bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) {
767   auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
768   SSI = {&F, [SE]() -> ScalarEvolution & { return *SE; }};
769   return false;
770 }
771 
772 AnalysisKey StackSafetyGlobalAnalysis::Key;
773 
774 StackSafetyGlobalInfo
775 StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
776   // FIXME: Lookup Module Summary.
777   FunctionAnalysisManager &FAM =
778       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
779   return {&M,
780           [&FAM](Function &F) -> const StackSafetyInfo & {
781             return FAM.getResult<StackSafetyAnalysis>(F);
782           },
783           nullptr};
784 }
785 
786 PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M,
787                                                     ModuleAnalysisManager &AM) {
788   OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n";
789   AM.getResult<StackSafetyGlobalAnalysis>(M).print(OS);
790   return PreservedAnalyses::all();
791 }
792 
793 char StackSafetyGlobalInfoWrapperPass::ID = 0;
794 
795 StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass()
796     : ModulePass(ID) {
797   initializeStackSafetyGlobalInfoWrapperPassPass(
798       *PassRegistry::getPassRegistry());
799 }
800 
801 StackSafetyGlobalInfoWrapperPass::~StackSafetyGlobalInfoWrapperPass() = default;
802 
803 void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O,
804                                              const Module *M) const {
805   SSGI.print(O);
806 }
807 
808 void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage(
809     AnalysisUsage &AU) const {
810   AU.setPreservesAll();
811   AU.addRequired<StackSafetyInfoWrapperPass>();
812 }
813 
814 bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) {
815   const ModuleSummaryIndex *ImportSummary = nullptr;
816   if (auto *IndexWrapperPass =
817           getAnalysisIfAvailable<ImmutableModuleSummaryIndexWrapperPass>())
818     ImportSummary = IndexWrapperPass->getIndex();
819 
820   SSGI = {&M,
821           [this](Function &F) -> const StackSafetyInfo & {
822             return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult();
823           },
824           ImportSummary};
825   return false;
826 }
827 
828 bool llvm::needsParamAccessSummary(const Module &M) {
829   if (StackSafetyRun)
830     return true;
831   for (auto &F : M.functions())
832     if (F.hasFnAttribute(Attribute::SanitizeMemTag))
833       return true;
834   return false;
835 }
836 
// Legacy pass manager registration of the local (per-function) analysis.
static const char LocalPassArg[] = "stack-safety-local";
static const char LocalPassName[] = "Stack Safety Local Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                      false, true)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
                    false, true)

// Legacy pass manager registration of the module-level analysis, registered
// under DEBUG_TYPE ("stack-safety").
static const char GlobalPassName[] = "Stack Safety Analysis";
INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                      GlobalPassName, false, true)
INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ImmutableModuleSummaryIndexWrapperPass)
INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
                    GlobalPassName, false, true)
852