1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
// - Deduplication of runtime calls, e.g., omp_get_num_threads.
// - Deletion of parallel regions whose callback only reads memory and is
//   known to return.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/IPO/OpenMPOpt.h"
16 
17 #include "llvm/ADT/EnumeratedArray.h"
18 #include "llvm/ADT/Statistic.h"
19 #include "llvm/Analysis/CallGraph.h"
20 #include "llvm/Analysis/CallGraphSCCPass.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/Frontend/OpenMP/OMPConstants.h"
23 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Transforms/IPO.h"
27 #include "llvm/Transforms/IPO/Attributor.h"
28 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
29 
30 using namespace llvm;
31 using namespace omp;
32 using namespace types;
33 
34 #define DEBUG_TYPE "openmp-opt"
35 
/// Hidden command line flag "openmp-opt-disable" (default: false). Both pass
/// entry points in this file check it and return without transforming anything
/// when it is set.
static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::ZeroOrMore,
    cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
    cl::init(false));

// Counters bumped by the transformations below (calls deduplicated, parallel
// regions deleted) and while the runtime function declarations and their uses
// are collected.
STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");

// Prefix for the debug output of this file. It is only defined when NDEBUG is
// not set; every use in this file sits inside an LLVM_DEBUG statement.
#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif
53 
54 namespace {
55 
56 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
57 /// Attributor runs.
58 struct OMPInformationCache : public InformationCache {
  /// Create the cache for module \p M restricted to \p ModuleSlice.
  ///
  /// \p M, \p AG, \p Allocator and \p CGSCC are forwarded to the
  /// InformationCache base class. \p ModuleSlice is stored by reference and
  /// must be non-empty: initializeRuntimeFunctions() obtains the module from
  /// its first element.
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
                      SmallPtrSetImpl<Function *> &ModuleSlice)
      : InformationCache(M, AG, Allocator, CGSCC), ModuleSlice(ModuleSlice),
        OMPBuilder(M) {
    // NOTE(review): presumably the OMP_RTL entries expanded in
    // initializeRuntimeFunctions() refer to the types set up here, which would
    // make this order mandatory -- confirm against OMPKinds.def.
    initializeTypes(M);
    initializeRuntimeFunctions();

    OMPBuilder.initialize();
  }
69 
70   /// Generic information that describes a runtime function
71   struct RuntimeFunctionInfo {
72 
73     /// The kind, as described by the RuntimeFunction enum.
74     RuntimeFunction Kind;
75 
76     /// The name of the function.
77     StringRef Name;
78 
79     /// Flag to indicate a variadic function.
80     bool IsVarArg;
81 
82     /// The return type of the function.
83     Type *ReturnType;
84 
85     /// The argument types of the function.
86     SmallVector<Type *, 8> ArgumentTypes;
87 
88     /// The declaration if available.
89     Function *Declaration = nullptr;
90 
91     /// Uses of this runtime function per function containing the use.
92     using UseVector = SmallVector<Use *, 16>;
93 
94     /// Return the vector of uses in function \p F.
95     UseVector &getOrCreateUseVector(Function *F) {
96       std::unique_ptr<UseVector> &UV = UsesMap[F];
97       if (!UV)
98         UV = std::make_unique<UseVector>();
99       return *UV;
100     }
101 
102     /// Return the vector of uses in function \p F or `nullptr` if there are
103     /// none.
104     const UseVector *getUseVector(Function &F) const {
105       auto I = UsesMap.find(&F);
106       if (I != UsesMap.end())
107         return I->second.get();
108       return nullptr;
109     }
110 
111     /// Return how many functions contain uses of this runtime function.
112     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
113 
114     /// Return the number of arguments (or the minimal number for variadic
115     /// functions).
116     size_t getNumArgs() const { return ArgumentTypes.size(); }
117 
118     /// Run the callback \p CB on each use and forget the use if the result is
119     /// true. The callback will be fed the function in which the use was
120     /// encountered as second argument.
121     void foreachUse(function_ref<bool(Use &, Function &)> CB) {
122       for (auto &It : UsesMap)
123         foreachUse(CB, It.first, It.second.get());
124     }
125 
126     /// Run the callback \p CB on each use within the function \p F and forget
127     /// the use if the result is true.
128     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F,
129                     UseVector *Uses = nullptr) {
130       SmallVector<unsigned, 8> ToBeDeleted;
131       ToBeDeleted.clear();
132 
133       unsigned Idx = 0;
134       UseVector &UV = Uses ? *Uses : getOrCreateUseVector(F);
135 
136       for (Use *U : UV) {
137         if (CB(*U, *F))
138           ToBeDeleted.push_back(Idx);
139         ++Idx;
140       }
141 
142       // Remove the to-be-deleted indices in reverse order as prior
143       // modifcations will not modify the smaller indices.
144       while (!ToBeDeleted.empty()) {
145         unsigned Idx = ToBeDeleted.pop_back_val();
146         UV[Idx] = UV.back();
147         UV.pop_back();
148       }
149     }
150 
151   private:
152     /// Map from functions to all uses of this runtime function contained in
153     /// them.
154     DenseMap<Function *, std::unique_ptr<UseVector>> UsesMap;
155   };
156 
157   /// The slice of the module we are allowed to look at.
158   SmallPtrSetImpl<Function *> &ModuleSlice;
159 
160   /// An OpenMP-IR-Builder instance
161   OpenMPIRBuilder OMPBuilder;
162 
163   /// Map from runtime function kind to the runtime function description.
164   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
165                   RuntimeFunction::OMPRTL___last>
166       RFIs;
167 
168   /// Returns true if the function declaration \p F matches the runtime
169   /// function types, that is, return type \p RTFRetType, and argument types
170   /// \p RTFArgTypes.
171   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
172                                   SmallVector<Type *, 8> &RTFArgTypes) {
173     // TODO: We should output information to the user (under debug output
174     //       and via remarks).
175 
176     if (!F)
177       return false;
178     if (F->getReturnType() != RTFRetType)
179       return false;
180     if (F->arg_size() != RTFArgTypes.size())
181       return false;
182 
183     auto RTFTyIt = RTFArgTypes.begin();
184     for (Argument &Arg : F->args()) {
185       if (Arg.getType() != *RTFTyIt)
186         return false;
187 
188       ++RTFTyIt;
189     }
190 
191     return true;
192   }
193 
194   /// Helper to initialize all runtime function information for those defined
195   /// in OpenMPKinds.def.
196   void initializeRuntimeFunctions() {
197     // Helper to collect all uses of the decleration in the UsesMap.
198     auto CollectUses = [&](RuntimeFunctionInfo &RFI) {
199       unsigned NumUses = 0;
200       if (!RFI.Declaration)
201         return NumUses;
202       OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
203 
204       NumOpenMPRuntimeFunctionsIdentified += 1;
205       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
206 
207       // TODO: We directly convert uses into proper calls and unknown uses.
208       for (Use &U : RFI.Declaration->uses()) {
209         if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
210           if (ModuleSlice.count(UserI->getFunction())) {
211             RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
212             ++NumUses;
213           }
214         } else {
215           RFI.getOrCreateUseVector(nullptr).push_back(&U);
216           ++NumUses;
217         }
218       }
219       return NumUses;
220     };
221 
222     Module &M = *((*ModuleSlice.begin())->getParent());
223 
224 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
225   {                                                                            \
226     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
227     Function *F = M.getFunction(_Name);                                        \
228     if (declMatchesRTFTypes(F, _ReturnType, ArgsTypes)) {                      \
229       auto &RFI = RFIs[_Enum];                                                 \
230       RFI.Kind = _Enum;                                                        \
231       RFI.Name = _Name;                                                        \
232       RFI.IsVarArg = _IsVarArg;                                                \
233       RFI.ReturnType = _ReturnType;                                            \
234       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
235       RFI.Declaration = F;                                                     \
236       unsigned NumUses = CollectUses(RFI);                                     \
237       (void)NumUses;                                                           \
238       LLVM_DEBUG({                                                             \
239         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
240                << " found\n";                                                  \
241         if (RFI.Declaration)                                                   \
242           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
243                  << RFI.getNumFunctionsWithUses()                              \
244                  << " different functions.\n";                                 \
245       });                                                                      \
246     }                                                                          \
247   }
248 #include "llvm/Frontend/OpenMP/OMPKinds.def"
249 
250     // TODO: We should attach the attributes defined in OMPKinds.def.
251   }
252 };
253 
254 struct OpenMPOpt {
255 
256   using OptimizationRemarkGetter =
257       function_ref<OptimizationRemarkEmitter &(Function *)>;
258 
259   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
260             OptimizationRemarkGetter OREGetter,
261             OMPInformationCache &OMPInfoCache)
262       : M(*(*SCC.begin())->getParent()), SCC(SCC),
263         ModuleSlice(OMPInfoCache.ModuleSlice), CGUpdater(CGUpdater),
264         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache) {}
265 
266   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
267   bool run() {
268     bool Changed = false;
269 
270     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
271                       << " functions in a slice with " << ModuleSlice.size()
272                       << " functions\n");
273 
274     Changed |= deduplicateRuntimeCalls();
275     Changed |= deleteParallelRegions();
276 
277     return Changed;
278   }
279 
280   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
281   /// given it has to be the callee or a nullptr is returned.
282   static CallInst *getCallIfRegularCall(
283       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
284     CallInst *CI = dyn_cast<CallInst>(U.getUser());
285     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
286         (!RFI || CI->getCalledFunction() == RFI->Declaration))
287       return CI;
288     return nullptr;
289   }
290 
291   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
292   /// the callee or a nullptr is returned.
293   static CallInst *getCallIfRegularCall(
294       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
295     CallInst *CI = dyn_cast<CallInst>(&V);
296     if (CI && !CI->hasOperandBundles() &&
297         (!RFI || CI->getCalledFunction() == RFI->Declaration))
298       return CI;
299     return nullptr;
300   }
301 
302 private:
303   /// Try to delete parallel regions if possible.
304   bool deleteParallelRegions() {
305     const unsigned CallbackCalleeOperand = 2;
306 
307     OMPInformationCache::RuntimeFunctionInfo &RFI =
308         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
309 
310     if (!RFI.Declaration)
311       return false;
312 
313     bool Changed = false;
314     auto DeleteCallCB = [&](Use &U, Function &) {
315       CallInst *CI = getCallIfRegularCall(U);
316       if (!CI)
317         return false;
318       auto *Fn = dyn_cast<Function>(
319           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
320       if (!Fn)
321         return false;
322       if (!Fn->onlyReadsMemory())
323         return false;
324       if (!Fn->hasFnAttribute(Attribute::WillReturn))
325         return false;
326 
327       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
328                         << CI->getCaller()->getName() << "\n");
329 
330       auto Remark = [&](OptimizationRemark OR) {
331         return OR << "Parallel region in "
332                   << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
333                   << " deleted";
334       };
335       emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
336                                      Remark);
337 
338       CGUpdater.removeCallSite(*CI);
339       CI->eraseFromParent();
340       Changed = true;
341       ++NumOpenMPParallelRegionsDeleted;
342       return true;
343     };
344 
345     RFI.foreachUse(DeleteCallCB);
346 
347     return Changed;
348   }
349 
350   /// Try to eliminiate runtime calls by reusing existing ones.
351   bool deduplicateRuntimeCalls() {
352     bool Changed = false;
353 
354     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
355         OMPRTL_omp_get_num_threads,
356         OMPRTL_omp_in_parallel,
357         OMPRTL_omp_get_cancellation,
358         OMPRTL_omp_get_thread_limit,
359         OMPRTL_omp_get_supported_active_levels,
360         OMPRTL_omp_get_level,
361         OMPRTL_omp_get_ancestor_thread_num,
362         OMPRTL_omp_get_team_size,
363         OMPRTL_omp_get_active_level,
364         OMPRTL_omp_in_final,
365         OMPRTL_omp_get_proc_bind,
366         OMPRTL_omp_get_num_places,
367         OMPRTL_omp_get_num_procs,
368         OMPRTL_omp_get_place_num,
369         OMPRTL_omp_get_partition_num_places,
370         OMPRTL_omp_get_partition_place_nums};
371 
372     // Global-tid is handled separately.
373     SmallSetVector<Value *, 16> GTIdArgs;
374     collectGlobalThreadIdArguments(GTIdArgs);
375     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
376                       << " global thread ID arguments\n");
377 
378     for (Function *F : SCC) {
379       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
380         deduplicateRuntimeCalls(*F,
381                                 OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
382 
383       // __kmpc_global_thread_num is special as we can replace it with an
384       // argument in enough cases to make it worth trying.
385       Value *GTIdArg = nullptr;
386       for (Argument &Arg : F->args())
387         if (GTIdArgs.count(&Arg)) {
388           GTIdArg = &Arg;
389           break;
390         }
391       Changed |= deduplicateRuntimeCalls(
392           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
393     }
394 
395     return Changed;
396   }
397 
398   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
399                                     bool GlobalOnly, bool &SingleChoice) {
400     if (CurrentIdent == NextIdent)
401       return CurrentIdent;
402 
403     // TODO: Figure out how to actually combine multiple debug locations. For
404     //       now we just keep an existing one if there is a single choice.
405     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
406       SingleChoice = !CurrentIdent;
407       return NextIdent;
408     }
409     return nullptr;
410   }
411 
412   /// Return an `struct ident_t*` value that represents the ones used in the
413   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
414   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
415   /// return value we create one from scratch. We also do not yet combine
416   /// information, e.g., the source locations, see combinedIdentStruct.
417   Value *
418   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
419                                  Function &F, bool GlobalOnly) {
420     bool SingleChoice = true;
421     Value *Ident = nullptr;
422     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
423       CallInst *CI = getCallIfRegularCall(U, &RFI);
424       if (!CI || &F != &Caller)
425         return false;
426       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
427                                   /* GlobalOnly */ true, SingleChoice);
428       return false;
429     };
430     RFI.foreachUse(CombineIdentStruct);
431 
432     if (!Ident || !SingleChoice) {
433       // The IRBuilder uses the insertion block to get to the module, this is
434       // unfortunate but we work around it for now.
435       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
436         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
437             &F.getEntryBlock(), F.getEntryBlock().begin()));
438       // Create a fallback location if non was found.
439       // TODO: Use the debug locations of the calls instead.
440       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
441       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
442     }
443     return Ident;
444   }
445 
446   /// Try to eliminiate calls of \p RFI in \p F by reusing an existing one or
447   /// \p ReplVal if given.
448   bool deduplicateRuntimeCalls(Function &F,
449                                OMPInformationCache::RuntimeFunctionInfo &RFI,
450                                Value *ReplVal = nullptr) {
451     auto *UV = RFI.getUseVector(F);
452     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
453       return false;
454 
455     LLVM_DEBUG(
456         dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
457                << (ReplVal ? " with an existing value\n" : "\n") << "\n");
458 
459     assert((!ReplVal || (isa<Argument>(ReplVal) &&
460                          cast<Argument>(ReplVal)->getParent() == &F)) &&
461            "Unexpected replacement value!");
462 
463     // TODO: Use dominance to find a good position instead.
464     auto CanBeMoved = [](CallBase &CB) {
465       unsigned NumArgs = CB.getNumArgOperands();
466       if (NumArgs == 0)
467         return true;
468       if (CB.getArgOperand(0)->getType() != IdentPtr)
469         return false;
470       for (unsigned u = 1; u < NumArgs; ++u)
471         if (isa<Instruction>(CB.getArgOperand(u)))
472           return false;
473       return true;
474     };
475 
476     if (!ReplVal) {
477       for (Use *U : *UV)
478         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
479           if (!CanBeMoved(*CI))
480             continue;
481 
482           auto Remark = [&](OptimizationRemark OR) {
483             auto newLoc = &*F.getEntryBlock().getFirstInsertionPt();
484             return OR << "OpenMP runtime call "
485                       << ore::NV("OpenMPOptRuntime", RFI.Name) << " moved to "
486                       << ore::NV("OpenMPRuntimeMoves", newLoc->getDebugLoc());
487           };
488           emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeCodeMotion", Remark);
489 
490           CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
491           ReplVal = CI;
492           break;
493         }
494       if (!ReplVal)
495         return false;
496     }
497 
498     // If we use a call as a replacement value we need to make sure the ident is
499     // valid at the new location. For now we just pick a global one, either
500     // existing and used by one of the calls, or created from scratch.
501     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
502       if (CI->getNumArgOperands() > 0 &&
503           CI->getArgOperand(0)->getType() == IdentPtr) {
504         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
505                                                       /* GlobalOnly */ true);
506         CI->setArgOperand(0, Ident);
507       }
508     }
509 
510     bool Changed = false;
511     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
512       CallInst *CI = getCallIfRegularCall(U, &RFI);
513       if (!CI || CI == ReplVal || &F != &Caller)
514         return false;
515       assert(CI->getCaller() == &F && "Unexpected call!");
516 
517       auto Remark = [&](OptimizationRemark OR) {
518         return OR << "OpenMP runtime call "
519                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
520       };
521       emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeDeduplicated", Remark);
522 
523       CGUpdater.removeCallSite(*CI);
524       CI->replaceAllUsesWith(ReplVal);
525       CI->eraseFromParent();
526       ++NumOpenMPRuntimeCallsDeduplicated;
527       Changed = true;
528       return true;
529     };
530     RFI.foreachUse(ReplaceAndDeleteCB);
531 
532     return Changed;
533   }
534 
535   /// Collect arguments that represent the global thread id in \p GTIdArgs.
536   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
537     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
538     //       initialization. We could define an AbstractAttribute instead and
539     //       run the Attributor here once it can be run as an SCC pass.
540 
541     // Helper to check the argument \p ArgNo at all call sites of \p F for
542     // a GTId.
543     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
544       if (!F.hasLocalLinkage())
545         return false;
546       for (Use &U : F.uses()) {
547         if (CallInst *CI = getCallIfRegularCall(U)) {
548           Value *ArgOp = CI->getArgOperand(ArgNo);
549           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
550               getCallIfRegularCall(
551                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
552             continue;
553         }
554         return false;
555       }
556       return true;
557     };
558 
559     // Helper to identify uses of a GTId as GTId arguments.
560     auto AddUserArgs = [&](Value &GTId) {
561       for (Use &U : GTId.uses())
562         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
563           if (CI->isArgOperand(&U))
564             if (Function *Callee = CI->getCalledFunction())
565               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
566                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
567     };
568 
569     // The argument users of __kmpc_global_thread_num calls are GTIds.
570     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
571         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
572 
573     GlobThreadNumRFI.foreachUse([&](Use &U, Function &F) {
574       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
575         AddUserArgs(*CI);
576       return false;
577     });
578 
579     // Transitively search for more arguments by looking at the users of the
580     // ones we know already. During the search the GTIdArgs vector is extended
581     // so we cannot cache the size nor can we use a range based for.
582     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
583       AddUserArgs(*GTIdArgs[u]);
584   }
585 
586   /// Emit a remark generically
587   ///
588   /// This template function can be used to generically emit a remark. The
589   /// RemarkKind should be one of the following:
590   ///   - OptimizationRemark to indicate a successful optimization attempt
591   ///   - OptimizationRemarkMissed to report a failed optimization attempt
592   ///   - OptimizationRemarkAnalysis to provide additional information about an
593   ///     optimization attempt
594   ///
595   /// The remark is built using a callback function provided by the caller that
596   /// takes a RemarkKind as input and returns a RemarkKind.
597   template <typename RemarkKind,
598             typename RemarkCallBack = function_ref<RemarkKind(RemarkKind &&)>>
599   void emitRemark(Instruction *Inst, StringRef RemarkName,
600                   RemarkCallBack &&RemarkCB) {
601     Function *F = Inst->getParent()->getParent();
602     auto &ORE = OREGetter(F);
603 
604     ORE.emit(
605         [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, Inst)); });
606   }
607 
  /// The underlying module.
609   Module &M;
610 
611   /// The SCC we are operating on.
612   SmallVectorImpl<Function *> &SCC;
613 
614   /// The slice of the module we are allowed to look at.
615   SmallPtrSetImpl<Function *> &ModuleSlice;
616 
  /// Updater used to keep the call graph in sync, e.g., when a call site is
  /// removed.
619   CallGraphUpdater &CGUpdater;
620 
621   /// Callback to get an OptimizationRemarkEmitter from a Function *
622   OptimizationRemarkGetter OREGetter;
623 
624   /// OpenMP-specific information cache. Also Used for Attributor runs.
625   OMPInformationCache &OMPInfoCache;
626 };
627 } // namespace
628 
629 PreservedAnalyses OpenMPOptPass::run(LazyCallGraph::SCC &C,
630                                      CGSCCAnalysisManager &AM,
631                                      LazyCallGraph &CG, CGSCCUpdateResult &UR) {
632   if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule))
633     return PreservedAnalyses::all();
634 
635   if (DisableOpenMPOptimizations)
636     return PreservedAnalyses::all();
637 
638   SmallPtrSet<Function *, 16> ModuleSlice;
639   SmallVector<Function *, 16> SCC;
640   for (LazyCallGraph::Node &N : C) {
641     SCC.push_back(&N.getFunction());
642     ModuleSlice.insert(SCC.back());
643   }
644 
645   if (SCC.empty())
646     return PreservedAnalyses::all();
647 
648   FunctionAnalysisManager &FAM =
649       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
650 
651   AnalysisGetter AG(FAM);
652 
653   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
654     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
655   };
656 
657   CallGraphUpdater CGUpdater;
658   CGUpdater.initialize(CG, C, AM, UR);
659 
660   SetVector<Function *> Functions(SCC.begin(), SCC.end());
661   BumpPtrAllocator Allocator;
662   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
663                                 /*CGSCC*/ &Functions, ModuleSlice);
664 
665   // TODO: Compute the module slice we are allowed to look at.
666   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache);
667   bool Changed = OMPOpt.run();
668   (void)Changed;
669   return PreservedAnalyses::all();
670 }
671 
672 namespace {
673 
674 struct OpenMPOptLegacyPass : public CallGraphSCCPass {
675   CallGraphUpdater CGUpdater;
676   OpenMPInModule OMPInModule;
677   static char ID;
678 
679   OpenMPOptLegacyPass() : CallGraphSCCPass(ID) {
680     initializeOpenMPOptLegacyPassPass(*PassRegistry::getPassRegistry());
681   }
682 
683   void getAnalysisUsage(AnalysisUsage &AU) const override {
684     CallGraphSCCPass::getAnalysisUsage(AU);
685   }
686 
687   bool doInitialization(CallGraph &CG) override {
688     // Disable the pass if there is no OpenMP (runtime call) in the module.
689     containsOpenMP(CG.getModule(), OMPInModule);
690     return false;
691   }
692 
693   bool runOnSCC(CallGraphSCC &CGSCC) override {
694     if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule))
695       return false;
696     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
697       return false;
698 
699     SmallPtrSet<Function *, 16> ModuleSlice;
700     SmallVector<Function *, 16> SCC;
701     for (CallGraphNode *CGN : CGSCC)
702       if (Function *Fn = CGN->getFunction())
703         if (!Fn->isDeclaration()) {
704           SCC.push_back(Fn);
705           ModuleSlice.insert(Fn);
706         }
707 
708     if (SCC.empty())
709       return false;
710 
711     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
712     CGUpdater.initialize(CG, CGSCC);
713 
714     // Maintain a map of functions to avoid rebuilding the ORE
715     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
716     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
717       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
718       if (!ORE)
719         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
720       return *ORE;
721     };
722 
723     AnalysisGetter AG;
724     SetVector<Function *> Functions(SCC.begin(), SCC.end());
725     BumpPtrAllocator Allocator;
726     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
727                                   Allocator,
728                                   /*CGSCC*/ &Functions, ModuleSlice);
729 
730     // TODO: Compute the module slice we are allowed to look at.
731     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache);
732     return OMPOpt.run();
733   }
734 
735   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
736 };
737 
738 } // end anonymous namespace
739 
740 bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) {
741   if (OMPInModule.isKnown())
742     return OMPInModule;
743 
744 #define OMP_RTL(_Enum, _Name, ...)                                             \
745   if (M.getFunction(_Name))                                                    \
746     return OMPInModule = true;
747 #include "llvm/Frontend/OpenMP/OMPKinds.def"
748   return OMPInModule = false;
749 }
750 
// Definition of the static pass identifier declared in OpenMPOptLegacyPass.
char OpenMPOptLegacyPass::ID = 0;

// Register the legacy pass under the name "openmpopt" and declare its
// dependency on CallGraphWrapperPass, which runOnSCC queries for the call
// graph.
INITIALIZE_PASS_BEGIN(OpenMPOptLegacyPass, "openmpopt",
                      "OpenMP specific optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_END(OpenMPOptLegacyPass, "openmpopt",
                    "OpenMP specific optimizations", false, false)

/// Create a new, heap-allocated instance of the legacy OpenMP optimization
/// pass. NOTE(review): presumably ownership passes to the pass manager the
/// result is added to -- confirm at the call sites.
Pass *llvm::createOpenMPOptLegacyPass() { return new OpenMPOptLegacyPass(); }
760