//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
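//
// - Deletion of parallel regions (__kmpc_fork_call sites) whose outlined
//   callback only reads memory and is guaranteed to return.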
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"

using namespace llvm;
using namespace omp;
using namespace types;

#define DEBUG_TYPE "openmp-opt"

static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::ZeroOrMore,
    cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
                                    cl::Hidden);

STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");

#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif

namespace {

/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
                      SmallPtrSetImpl<Function *> &ModuleSlice)
      : InformationCache(M, AG, Allocator, CGSCC), ModuleSlice(ModuleSlice),
        OMPBuilder(M) {
    OMPBuilder.initialize();
    initializeRuntimeFunctions();
    initializeInternalControlVars();
  }

  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL function corresponding to the override clause of this ICV.
    RuntimeFunction Clause;
  };

  /// Generic information that describes a runtime function.
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::unique_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_unique<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as its second argument.
    void foreachUse(function_ref<bool(Use &, Function &)> CB) {
      for (auto &It : UsesMap)
        foreachUse(CB, It.first, It.second.get());
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F,
                    UseVector *Uses = nullptr) {
      SmallVector<unsigned, 8> ToBeDeleted;

      unsigned Idx = 0;
      UseVector &UV = Uses ? *Uses : getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as earlier removals,
      // which swap in the last element, will not affect the smaller indices.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::unique_ptr<UseVector>> UsesMap;
  };

  /// The slice of the module we are allowed to look at.
  SmallPtrSetImpl<Function *> &ModuleSlice;

  /// An OpenMP-IR-Builder instance.
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;

  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
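  ///
  /// An ICV_DATA_ENV entry of roughly the form
  ///   ICV_DATA_ENV(ICV_nthreads, "nthreads-var", "OMP_NUM_THREADS",
  ///                ICV_IMPLEMENTATION_DEFINED)
  /// (illustrative only; see OMPKinds.def for the authoritative entries)
  /// provides the name, environment variable, and initial value kind below.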
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.Name = _Name;                                                          \
    ICV.Kind = Enum;                                                           \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue =                                                          \
          ConstantInt::get(Type::getInt32Ty(Int32->getContext()), 0);          \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(Int1->getContext());               \
      break;                                                                   \
    case ICV_LAST:                                                             \
      break;                                                                   \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    //       and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OMPKinds.def.
  void initializeRuntimeFunctions() {
    // Helper to collect all uses of the declaration in the UsesMap.
    auto CollectUses = [&](RuntimeFunctionInfo &RFI) {
      unsigned NumUses = 0;
      if (!RFI.Declaration)
        return NumUses;
      OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();

      // TODO: We directly convert uses into proper calls and unknown uses.
      for (Use &U : RFI.Declaration->uses()) {
        if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
          if (ModuleSlice.count(UserI->getFunction())) {
            RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
            ++NumUses;
          }
        } else {
          RFI.getOrCreateUseVector(nullptr).push_back(&U);
          ++NumUses;
        }
      }
      return NumUses;
    };

    Module &M = *((*ModuleSlice.begin())->getParent());

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    if (declMatchesRTFTypes(F, _ReturnType, ArgsTypes)) {                      \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = _ReturnType;                                            \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = CollectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }
};

struct OpenMPOpt {

  using OptimizationRemarkGetter =
      function_ref<OptimizationRemarkEmitter &(Function *)>;

  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache) {}

  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
  bool run() {
    bool Changed = false;

    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
                      << " functions in a slice with "
                      << OMPInfoCache.ModuleSlice.size() << " functions\n");

    // Print initial ICV values for testing.
    // FIXME: This should be done from the Attributor once it is added.
    if (PrintICVValues) {
      InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel};

      for (Function *F : OMPInfoCache.ModuleSlice) {
        for (auto ICV : ICVs) {
          auto ICVInfo = OMPInfoCache.ICVs[ICV];
          auto Remark = [&](OptimizationRemark OR) {
            return OR << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                      << " Value: "
                      << (ICVInfo.InitValue
                              ? ICVInfo.InitValue->getValue().toString(10, true)
                              : "IMPLEMENTATION_DEFINED");
          };

          emitRemarkOnFunction(F, "OpenMPICVTracker", Remark);
        }
      }
    }

    Changed |= deduplicateRuntimeCalls();
    Changed |= deleteParallelRegions();

    return Changed;
  }

  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given, the callee has to be \p RFI's declaration or nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call. If \p RFI is given, the callee
  /// has to be \p RFI's declaration or nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

private:
  /// Try to delete parallel regions if possible.
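  ///
  /// A region is deletable if the outlined callback passed to
  /// __kmpc_fork_call only reads memory (`readonly`) and is guaranteed to
  /// return (`willreturn`); such a fork call has no observable effect and can
  /// be erased. Illustrative source (assuming the body has no other effects):
  ///   #pragma omp parallel
  ///   { /* reads memory only, result unused */ }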
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
                        << CI->getCaller()->getName() << "\n");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Parallel region in "
                  << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
                  << " deleted";
      };
      emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
                                     Remark);

      CGUpdater.removeCallSite(*CI);
      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(DeleteCallCB);

    return Changed;
  }

  /// Try to eliminate runtime calls by reusing existing ones.
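  ///
  /// For example, repeated queries of the global thread id in one function can
  /// be collapsed into a single call (illustrative IR sketch, made-up names):
  ///   %gtid0 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   ...
  ///   %gtid1 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  /// Uses of %gtid1 are rewritten to %gtid0 and the second call is erased.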
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_thread_limit,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }

  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                    bool GlobalOnly, bool &SingleChoice) {
    if (CurrentIdent == NextIdent)
      return CurrentIdent;

    // TODO: Figure out how to actually combine multiple debug locations. For
    //       now we just keep an existing one if there is a single choice.
    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
      SingleChoice = !CurrentIdent;
      return NextIdent;
    }
    return nullptr;
  }

  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value, we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
  Value *
  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
                                 Function &F, bool GlobalOnly) {
    bool SingleChoice = true;
    Value *Ident = nullptr;
    auto CombineIdentStruct = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || &F != &Caller)
        return false;
      Ident = combinedIdentStruct(Ident, CI->getArgOperand(0), GlobalOnly,
                                  SingleChoice);
      return false;
    };
    RFI.foreachUse(CombineIdentStruct);

    if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module; this is
      // unfortunate but we work around it for now.
      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
        OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
            &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
      // TODO: Use the debug locations of the calls instead.
      Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
    }
    return Ident;
  }

  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
  /// \p ReplVal if given.
  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)
      return false;

    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value" : "") << "\n");

    assert((!ReplVal || (isa<Argument>(ReplVal) &&
                         cast<Argument>(ReplVal)->getParent() == &F)) &&
           "Unexpected replacement value!");

    // TODO: Use dominance to find a good position instead.
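    // A call can only be hoisted to the function entry if, apart from a
    // leading `ident_t *` argument, none of its operands are instructions,
    // since an instruction operand might not dominate the new position.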
    auto CanBeMoved = [](CallBase &CB) {
      unsigned NumArgs = CB.getNumArgOperands();
      if (NumArgs == 0)
        return true;
      if (CB.getArgOperand(0)->getType() != IdentPtr)
        return false;
      for (unsigned u = 1; u < NumArgs; ++u)
        if (isa<Instruction>(CB.getArgOperand(u)))
          return false;
      return true;
    };

    if (!ReplVal) {
      for (Use *U : *UV)
        if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
          if (!CanBeMoved(*CI))
            continue;

          auto Remark = [&](OptimizationRemark OR) {
            auto NewLoc = &*F.getEntryBlock().getFirstInsertionPt();
            return OR << "OpenMP runtime call "
                      << ore::NV("OpenMPOptRuntime", RFI.Name) << " moved to "
                      << ore::NV("OpenMPRuntimeMoves", NewLoc->getDebugLoc());
          };
          emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeCodeMotion", Remark);

          CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
          ReplVal = CI;
          break;
        }
      if (!ReplVal)
        return false;
    }

    // If we use a call as a replacement value we need to make sure the ident
    // is valid at the new location. For now we just pick a global one, either
    // existing and used by one of the calls, or created from scratch.
    if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
      if (CI->getNumArgOperands() > 0 &&
          CI->getArgOperand(0)->getType() == IdentPtr) {
        Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
                                                      /* GlobalOnly */ true);
        CI->setArgOperand(0, Ident);
      }
    }

    bool Changed = false;
    auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || CI == ReplVal || &F != &Caller)
        return false;
      assert(CI->getCaller() == &F && "Unexpected call!");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "OpenMP runtime call "
                  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
      };
      emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeDeduplicated", Remark);

      CGUpdater.removeCallSite(*CI);
      CI->replaceAllUsesWith(ReplVal);
      CI->eraseFromParent();
      ++NumOpenMPRuntimeCallsDeduplicated;
      Changed = true;
      return true;
    };
    RFI.foreachUse(ReplaceAndDeleteCB);

    return Changed;
  }

  /// Collect arguments that represent the global thread id in \p GTIdArgs.
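  ///
  /// An argument qualifies if, at every call site of its (local linkage)
  /// function, the passed value is either the result of a
  /// __kmpc_global_thread_num call or an already known GTId argument, e.g.
  /// (illustrative IR sketch):
  ///   %gtid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   call void @outlined(i32 %gtid)
  /// Here the first parameter of @outlined is recorded as a GTId argument.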
  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
    // TODO: Below we basically perform a fixpoint iteration with a pessimistic
    //       initialization. We could define an AbstractAttribute instead and
    //       run the Attributor here once it can be run as an SCC pass.

    // Helper to check the argument \p ArgNo at all call sites of \p F for
    // a GTId.
    auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
      if (!F.hasLocalLinkage())
        return false;
      for (Use &U : F.uses()) {
        if (CallInst *CI = getCallIfRegularCall(U)) {
          Value *ArgOp = CI->getArgOperand(ArgNo);
          if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
              getCallIfRegularCall(
                  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
            continue;
        }
        return false;
      }
      return true;
    };

    // Helper to identify uses of a GTId as GTId arguments.
    auto AddUserArgs = [&](Value &GTId) {
      for (Use &U : GTId.uses())
        if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
          if (CI->isArgOperand(&U))
            if (Function *Callee = CI->getCalledFunction())
              if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
                GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
    };

    // The argument users of __kmpc_global_thread_num calls are GTIds.
    OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];

    GlobThreadNumRFI.foreachUse([&](Use &U, Function &F) {
      if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
        AddUserArgs(*CI);
      return false;
    });

    // Transitively search for more arguments by looking at the users of the
    // ones we know already. During the search the GTIdArgs vector is extended,
    // so we can neither cache the size nor use a range-based for loop.
    for (unsigned u = 0; u < GTIdArgs.size(); ++u)
      AddUserArgs(*GTIdArgs[u]);
  }

  /// Emit a remark generically.
  ///
  /// This template function can be used to generically emit a remark. The
  /// RemarkKind should be one of the following:
  ///   - OptimizationRemark to indicate a successful optimization attempt
  ///   - OptimizationRemarkMissed to report a failed optimization attempt
  ///   - OptimizationRemarkAnalysis to provide additional information about an
  ///     optimization attempt
  ///
  /// The remark is built using a callback function provided by the caller that
  /// takes a RemarkKind as input and returns a RemarkKind.
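  ///
  /// A typical call site, as used for deduplication below, looks like:
  /// \code
  ///   emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeDeduplicated",
  ///                                  [&](OptimizationRemark OR) {
  ///                                    return OR << "call deduplicated";
  ///                                  });
  /// \endcode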
  template <typename RemarkKind,
            typename RemarkCallBack = function_ref<RemarkKind(RemarkKind &&)>>
  void emitRemark(Instruction *Inst, StringRef RemarkName,
                  RemarkCallBack &&RemarkCB) {
    Function *F = Inst->getParent()->getParent();
    auto &ORE = OREGetter(F);

    ORE.emit(
        [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, Inst)); });
  }

  /// Emit a remark on a function. Since only OptimizationRemark supports this,
  /// it can't be made generic.
  void emitRemarkOnFunction(
      Function *F, StringRef RemarkName,
      function_ref<OptimizationRemark(OptimizationRemark &&)> &&RemarkCB) {
    auto &ORE = OREGetter(F);

    ORE.emit([&]() {
      return RemarkCB(OptimizationRemark(DEBUG_TYPE, RemarkName, F));
    });
  }

  /// The underlying module.
  Module &M;

  /// The SCC we are operating on.
  SmallVectorImpl<Function *> &SCC;

  /// Helper to update the call graph, e.g., when call sites are removed or
  /// replaced.
  CallGraphUpdater &CGUpdater;

  /// Callback to get an OptimizationRemarkEmitter from a Function *.
  OptimizationRemarkGetter OREGetter;

  /// OpenMP-specific information cache. Also used for Attributor runs.
  OMPInformationCache &OMPInfoCache;
};
} // namespace

PreservedAnalyses OpenMPOptPass::run(LazyCallGraph::SCC &C,
                                     CGSCCAnalysisManager &AM,
                                     LazyCallGraph &CG, CGSCCUpdateResult &UR) {
  if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule))
    return PreservedAnalyses::all();

  if (DisableOpenMPOptimizations)
    return PreservedAnalyses::all();

  SmallPtrSet<Function *, 16> ModuleSlice;
  SmallVector<Function *, 16> SCC;
  for (LazyCallGraph::Node &N : C) {
    SCC.push_back(&N.getFunction());
    ModuleSlice.insert(SCC.back());
  }

  if (SCC.empty())
    return PreservedAnalyses::all();

  FunctionAnalysisManager &FAM =
      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

  AnalysisGetter AG(FAM);

  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  };

  CallGraphUpdater CGUpdater;
  CGUpdater.initialize(CG, C, AM, UR);

  SetVector<Function *> Functions(SCC.begin(), SCC.end());
  BumpPtrAllocator Allocator;
  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
                                /*CGSCC*/ &Functions, ModuleSlice);

  // TODO: Compute the module slice we are allowed to look at.
  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache);
  bool Changed = OMPOpt.run();
  (void)Changed;
  return PreservedAnalyses::all();
}

namespace {

struct OpenMPOptLegacyPass : public CallGraphSCCPass {
  CallGraphUpdater CGUpdater;
  OpenMPInModule OMPInModule;
  static char ID;

  OpenMPOptLegacyPass() : CallGraphSCCPass(ID) {
    initializeOpenMPOptLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  bool doInitialization(CallGraph &CG) override {
    // Disable the pass if there is no OpenMP (runtime call) in the module.
    containsOpenMP(CG.getModule(), OMPInModule);
    return false;
  }

  bool runOnSCC(CallGraphSCC &CGSCC) override {
    if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule))
      return false;
    if (DisableOpenMPOptimizations || skipSCC(CGSCC))
      return false;

    SmallPtrSet<Function *, 16> ModuleSlice;
    SmallVector<Function *, 16> SCC;
    for (CallGraphNode *CGN : CGSCC)
      if (Function *Fn = CGN->getFunction())
        if (!Fn->isDeclaration()) {
          SCC.push_back(Fn);
          ModuleSlice.insert(Fn);
        }

    if (SCC.empty())
      return false;

    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
    CGUpdater.initialize(CG, CGSCC);

    // Maintain a map of functions so the ORE is built at most once per
    // function.
    DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
    auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
      std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
      if (!ORE)
        ORE = std::make_unique<OptimizationRemarkEmitter>(F);
      return *ORE;
    };

    AnalysisGetter AG;
    SetVector<Function *> Functions(SCC.begin(), SCC.end());
    BumpPtrAllocator Allocator;
    OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
                                  Allocator,
                                  /*CGSCC*/ &Functions, ModuleSlice);

    // TODO: Compute the module slice we are allowed to look at.
    OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache);
    return OMPOpt.run();
  }

  bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
};

} // end anonymous namespace

bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) {
  if (OMPInModule.isKnown())
    return OMPInModule;

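  // A module "contains OpenMP" as soon as any known runtime function from
  // OMPKinds.def is declared in it; the first match suffices.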
#define OMP_RTL(_Enum, _Name, ...)                                             \
  if (M.getFunction(_Name))                                                    \
    return OMPInModule = true;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  return OMPInModule = false;
}

char OpenMPOptLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(OpenMPOptLegacyPass, "openmpopt",
                      "OpenMP specific optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_END(OpenMPOptLegacyPass, "openmpopt",
                    "OpenMP specific optimizations", false, false)

Pass *llvm::createOpenMPOptLegacyPass() { return new OpenMPOptLegacyPass(); }