1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/IPO/OpenMPOpt.h"
16 
17 #include "llvm/ADT/EnumeratedArray.h"
18 #include "llvm/ADT/Statistic.h"
19 #include "llvm/Analysis/CallGraph.h"
20 #include "llvm/Analysis/CallGraphSCCPass.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/Frontend/OpenMP/OMPConstants.h"
23 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Transforms/IPO.h"
27 #include "llvm/Transforms/IPO/Attributor.h"
28 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
29 
30 using namespace llvm;
31 using namespace omp;
32 using namespace types;
33 
34 #define DEBUG_TYPE "openmp-opt"
35 
// Escape hatch to turn this pass off entirely, even when OpenMP runtime calls
// are present in the module. Mostly useful for debugging and triaging.
static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::ZeroOrMore,
    cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
    cl::init(false));

// Pass statistics, reported via the -stats infrastructure.
STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");

#if !defined(NDEBUG)
// Prefix tag for all debug output of this pass (debug builds only).
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif
53 
54 namespace {
55 
56 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
57 /// Attributor runs.
/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  /// Build the cache for module \p M, restricted to the functions in
  /// \p ModuleSlice. Runtime function declarations are identified and their
  /// uses collected eagerly in the constructor.
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
                      SmallPtrSetImpl<Function *> &ModuleSlice)
      : InformationCache(M, AG, Allocator, CGSCC), ModuleSlice(ModuleSlice),
        OMPBuilder(M) {
    initializeTypes(M);
    initializeRuntimeFunctions();

    OMPBuilder.initialize();
  }

  /// Generic information that describes a runtime function
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Return the vector of uses in function \p F, creating an empty one if
    /// none exists yet. \p F may be nullptr for uses outside of instructions
    /// (e.g., in constant expressions).
    UseVector &getOrCreateUseVector(Function *F) {
      std::unique_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_unique<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
    ///
    /// NOTE(review): It.first may be nullptr for uses recorded outside of
    /// instructions (see getOrCreateUseVector), in which case the `*F`
    /// dereference in the overload below is invalid — TODO confirm such uses
    /// cannot occur for the RFIs iterated this way.
    void foreachUse(function_ref<bool(Use &, Function &)> CB) {
      for (auto &It : UsesMap)
        foreachUse(CB, It.first, It.second.get());
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F,
                    UseVector *Uses = nullptr) {
      SmallVector<unsigned, 8> ToBeDeleted;
      ToBeDeleted.clear();

      unsigned Idx = 0;
      UseVector &UV = Uses ? *Uses : getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices. Swap-with-back is
      // safe here because higher indices are removed first.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them. The nullptr key collects uses outside of instructions.
    DenseMap<Function *, std::unique_ptr<UseVector>> UsesMap;
  };

  /// The slice of the module we are allowed to look at.
  SmallPtrSetImpl<Function *> &ModuleSlice;

  /// An OpenMP-IR-Builder instance
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    //       and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
  void initializeRuntimeFunctions() {
    // Helper to collect all uses of the declaration in the UsesMap. Returns
    // the number of uses found inside the module slice; uses outside of
    // instructions are keyed under a nullptr function.
    auto CollectUses = [&](RuntimeFunctionInfo &RFI) {
      unsigned NumUses = 0;
      if (!RFI.Declaration)
        return NumUses;
      OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();

      // TODO: We directly convert uses into proper calls and unknown uses.
      for (Use &U : RFI.Declaration->uses()) {
        if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
          if (ModuleSlice.count(UserI->getFunction())) {
            RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
            ++NumUses;
          }
        } else {
          RFI.getOrCreateUseVector(nullptr).push_back(&U);
          ++NumUses;
        }
      }
      return NumUses;
    };

    Module &M = *((*ModuleSlice.begin())->getParent());

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    if (declMatchesRTFTypes(F, _ReturnType, ArgsTypes)) {                      \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = _ReturnType;                                            \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = CollectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }
};
253 
254 struct OpenMPOpt {
255 
256   using OptimizationRemarkGetter =
257       function_ref<OptimizationRemarkEmitter &(Function *)>;
258 
259   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
260             OptimizationRemarkGetter OREGetter,
261             OMPInformationCache &OMPInfoCache)
262       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
263         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache) {}
264 
265   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
266   bool run() {
267     bool Changed = false;
268 
269     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
270                       << " functions in a slice with "
271                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
272 
273     Changed |= deduplicateRuntimeCalls();
274     Changed |= deleteParallelRegions();
275 
276     return Changed;
277   }
278 
279   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
280   /// given it has to be the callee or a nullptr is returned.
281   static CallInst *getCallIfRegularCall(
282       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
283     CallInst *CI = dyn_cast<CallInst>(U.getUser());
284     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
285         (!RFI || CI->getCalledFunction() == RFI->Declaration))
286       return CI;
287     return nullptr;
288   }
289 
290   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
291   /// the callee or a nullptr is returned.
292   static CallInst *getCallIfRegularCall(
293       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
294     CallInst *CI = dyn_cast<CallInst>(&V);
295     if (CI && !CI->hasOperandBundles() &&
296         (!RFI || CI->getCalledFunction() == RFI->Declaration))
297       return CI;
298     return nullptr;
299   }
300 
301 private:
302   /// Try to delete parallel regions if possible.
303   bool deleteParallelRegions() {
304     const unsigned CallbackCalleeOperand = 2;
305 
306     OMPInformationCache::RuntimeFunctionInfo &RFI =
307         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
308 
309     if (!RFI.Declaration)
310       return false;
311 
312     bool Changed = false;
313     auto DeleteCallCB = [&](Use &U, Function &) {
314       CallInst *CI = getCallIfRegularCall(U);
315       if (!CI)
316         return false;
317       auto *Fn = dyn_cast<Function>(
318           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
319       if (!Fn)
320         return false;
321       if (!Fn->onlyReadsMemory())
322         return false;
323       if (!Fn->hasFnAttribute(Attribute::WillReturn))
324         return false;
325 
326       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
327                         << CI->getCaller()->getName() << "\n");
328 
329       auto Remark = [&](OptimizationRemark OR) {
330         return OR << "Parallel region in "
331                   << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
332                   << " deleted";
333       };
334       emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
335                                      Remark);
336 
337       CGUpdater.removeCallSite(*CI);
338       CI->eraseFromParent();
339       Changed = true;
340       ++NumOpenMPParallelRegionsDeleted;
341       return true;
342     };
343 
344     RFI.foreachUse(DeleteCallCB);
345 
346     return Changed;
347   }
348 
349   /// Try to eliminiate runtime calls by reusing existing ones.
350   bool deduplicateRuntimeCalls() {
351     bool Changed = false;
352 
353     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
354         OMPRTL_omp_get_num_threads,
355         OMPRTL_omp_in_parallel,
356         OMPRTL_omp_get_cancellation,
357         OMPRTL_omp_get_thread_limit,
358         OMPRTL_omp_get_supported_active_levels,
359         OMPRTL_omp_get_level,
360         OMPRTL_omp_get_ancestor_thread_num,
361         OMPRTL_omp_get_team_size,
362         OMPRTL_omp_get_active_level,
363         OMPRTL_omp_in_final,
364         OMPRTL_omp_get_proc_bind,
365         OMPRTL_omp_get_num_places,
366         OMPRTL_omp_get_num_procs,
367         OMPRTL_omp_get_place_num,
368         OMPRTL_omp_get_partition_num_places,
369         OMPRTL_omp_get_partition_place_nums};
370 
371     // Global-tid is handled separately.
372     SmallSetVector<Value *, 16> GTIdArgs;
373     collectGlobalThreadIdArguments(GTIdArgs);
374     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
375                       << " global thread ID arguments\n");
376 
377     for (Function *F : SCC) {
378       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
379         deduplicateRuntimeCalls(*F,
380                                 OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
381 
382       // __kmpc_global_thread_num is special as we can replace it with an
383       // argument in enough cases to make it worth trying.
384       Value *GTIdArg = nullptr;
385       for (Argument &Arg : F->args())
386         if (GTIdArgs.count(&Arg)) {
387           GTIdArg = &Arg;
388           break;
389         }
390       Changed |= deduplicateRuntimeCalls(
391           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
392     }
393 
394     return Changed;
395   }
396 
397   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
398                                     bool GlobalOnly, bool &SingleChoice) {
399     if (CurrentIdent == NextIdent)
400       return CurrentIdent;
401 
402     // TODO: Figure out how to actually combine multiple debug locations. For
403     //       now we just keep an existing one if there is a single choice.
404     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
405       SingleChoice = !CurrentIdent;
406       return NextIdent;
407     }
408     return nullptr;
409   }
410 
411   /// Return an `struct ident_t*` value that represents the ones used in the
412   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
413   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
414   /// return value we create one from scratch. We also do not yet combine
415   /// information, e.g., the source locations, see combinedIdentStruct.
416   Value *
417   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
418                                  Function &F, bool GlobalOnly) {
419     bool SingleChoice = true;
420     Value *Ident = nullptr;
421     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
422       CallInst *CI = getCallIfRegularCall(U, &RFI);
423       if (!CI || &F != &Caller)
424         return false;
425       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
426                                   /* GlobalOnly */ true, SingleChoice);
427       return false;
428     };
429     RFI.foreachUse(CombineIdentStruct);
430 
431     if (!Ident || !SingleChoice) {
432       // The IRBuilder uses the insertion block to get to the module, this is
433       // unfortunate but we work around it for now.
434       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
435         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
436             &F.getEntryBlock(), F.getEntryBlock().begin()));
437       // Create a fallback location if non was found.
438       // TODO: Use the debug locations of the calls instead.
439       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
440       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
441     }
442     return Ident;
443   }
444 
445   /// Try to eliminiate calls of \p RFI in \p F by reusing an existing one or
446   /// \p ReplVal if given.
447   bool deduplicateRuntimeCalls(Function &F,
448                                OMPInformationCache::RuntimeFunctionInfo &RFI,
449                                Value *ReplVal = nullptr) {
450     auto *UV = RFI.getUseVector(F);
451     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
452       return false;
453 
454     LLVM_DEBUG(
455         dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
456                << (ReplVal ? " with an existing value\n" : "\n") << "\n");
457 
458     assert((!ReplVal || (isa<Argument>(ReplVal) &&
459                          cast<Argument>(ReplVal)->getParent() == &F)) &&
460            "Unexpected replacement value!");
461 
462     // TODO: Use dominance to find a good position instead.
463     auto CanBeMoved = [](CallBase &CB) {
464       unsigned NumArgs = CB.getNumArgOperands();
465       if (NumArgs == 0)
466         return true;
467       if (CB.getArgOperand(0)->getType() != IdentPtr)
468         return false;
469       for (unsigned u = 1; u < NumArgs; ++u)
470         if (isa<Instruction>(CB.getArgOperand(u)))
471           return false;
472       return true;
473     };
474 
475     if (!ReplVal) {
476       for (Use *U : *UV)
477         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
478           if (!CanBeMoved(*CI))
479             continue;
480 
481           auto Remark = [&](OptimizationRemark OR) {
482             auto newLoc = &*F.getEntryBlock().getFirstInsertionPt();
483             return OR << "OpenMP runtime call "
484                       << ore::NV("OpenMPOptRuntime", RFI.Name) << " moved to "
485                       << ore::NV("OpenMPRuntimeMoves", newLoc->getDebugLoc());
486           };
487           emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeCodeMotion", Remark);
488 
489           CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
490           ReplVal = CI;
491           break;
492         }
493       if (!ReplVal)
494         return false;
495     }
496 
497     // If we use a call as a replacement value we need to make sure the ident is
498     // valid at the new location. For now we just pick a global one, either
499     // existing and used by one of the calls, or created from scratch.
500     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
501       if (CI->getNumArgOperands() > 0 &&
502           CI->getArgOperand(0)->getType() == IdentPtr) {
503         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
504                                                       /* GlobalOnly */ true);
505         CI->setArgOperand(0, Ident);
506       }
507     }
508 
509     bool Changed = false;
510     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
511       CallInst *CI = getCallIfRegularCall(U, &RFI);
512       if (!CI || CI == ReplVal || &F != &Caller)
513         return false;
514       assert(CI->getCaller() == &F && "Unexpected call!");
515 
516       auto Remark = [&](OptimizationRemark OR) {
517         return OR << "OpenMP runtime call "
518                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
519       };
520       emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeDeduplicated", Remark);
521 
522       CGUpdater.removeCallSite(*CI);
523       CI->replaceAllUsesWith(ReplVal);
524       CI->eraseFromParent();
525       ++NumOpenMPRuntimeCallsDeduplicated;
526       Changed = true;
527       return true;
528     };
529     RFI.foreachUse(ReplaceAndDeleteCB);
530 
531     return Changed;
532   }
533 
534   /// Collect arguments that represent the global thread id in \p GTIdArgs.
535   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
536     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
537     //       initialization. We could define an AbstractAttribute instead and
538     //       run the Attributor here once it can be run as an SCC pass.
539 
540     // Helper to check the argument \p ArgNo at all call sites of \p F for
541     // a GTId.
542     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
543       if (!F.hasLocalLinkage())
544         return false;
545       for (Use &U : F.uses()) {
546         if (CallInst *CI = getCallIfRegularCall(U)) {
547           Value *ArgOp = CI->getArgOperand(ArgNo);
548           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
549               getCallIfRegularCall(
550                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
551             continue;
552         }
553         return false;
554       }
555       return true;
556     };
557 
558     // Helper to identify uses of a GTId as GTId arguments.
559     auto AddUserArgs = [&](Value &GTId) {
560       for (Use &U : GTId.uses())
561         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
562           if (CI->isArgOperand(&U))
563             if (Function *Callee = CI->getCalledFunction())
564               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
565                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
566     };
567 
568     // The argument users of __kmpc_global_thread_num calls are GTIds.
569     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
570         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
571 
572     GlobThreadNumRFI.foreachUse([&](Use &U, Function &F) {
573       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
574         AddUserArgs(*CI);
575       return false;
576     });
577 
578     // Transitively search for more arguments by looking at the users of the
579     // ones we know already. During the search the GTIdArgs vector is extended
580     // so we cannot cache the size nor can we use a range based for.
581     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
582       AddUserArgs(*GTIdArgs[u]);
583   }
584 
585   /// Emit a remark generically
586   ///
587   /// This template function can be used to generically emit a remark. The
588   /// RemarkKind should be one of the following:
589   ///   - OptimizationRemark to indicate a successful optimization attempt
590   ///   - OptimizationRemarkMissed to report a failed optimization attempt
591   ///   - OptimizationRemarkAnalysis to provide additional information about an
592   ///     optimization attempt
593   ///
594   /// The remark is built using a callback function provided by the caller that
595   /// takes a RemarkKind as input and returns a RemarkKind.
596   template <typename RemarkKind,
597             typename RemarkCallBack = function_ref<RemarkKind(RemarkKind &&)>>
598   void emitRemark(Instruction *Inst, StringRef RemarkName,
599                   RemarkCallBack &&RemarkCB) {
600     Function *F = Inst->getParent()->getParent();
601     auto &ORE = OREGetter(F);
602 
603     ORE.emit(
604         [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, Inst)); });
605   }
606 
607   /// The underyling module.
608   Module &M;
609 
610   /// The SCC we are operating on.
611   SmallVectorImpl<Function *> &SCC;
612 
613   /// Callback to update the call graph, the first argument is a removed call,
614   /// the second an optional replacement call.
615   CallGraphUpdater &CGUpdater;
616 
617   /// Callback to get an OptimizationRemarkEmitter from a Function *
618   OptimizationRemarkGetter OREGetter;
619 
620   /// OpenMP-specific information cache. Also Used for Attributor runs.
621   OMPInformationCache &OMPInfoCache;
622 };
623 } // namespace
624 
/// New pass manager entry point: run the OpenMP optimizations on SCC \p C.
PreservedAnalyses OpenMPOptPass::run(LazyCallGraph::SCC &C,
                                     CGSCCAnalysisManager &AM,
                                     LazyCallGraph &CG, CGSCCUpdateResult &UR) {
  // Bail early if the module references no known OpenMP runtime function.
  if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule))
    return PreservedAnalyses::all();

  if (DisableOpenMPOptimizations)
    return PreservedAnalyses::all();

  // Collect the functions of the SCC; they also form the module slice the
  // optimization is allowed to inspect for now.
  SmallPtrSet<Function *, 16> ModuleSlice;
  SmallVector<Function *, 16> SCC;
  for (LazyCallGraph::Node &N : C) {
    SCC.push_back(&N.getFunction());
    ModuleSlice.insert(SCC.back());
  }

  if (SCC.empty())
    return PreservedAnalyses::all();

  FunctionAnalysisManager &FAM =
      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

  AnalysisGetter AG(FAM);

  // Lazily obtain remark emitters through the function analysis manager.
  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  };

  CallGraphUpdater CGUpdater;
  CGUpdater.initialize(CG, C, AM, UR);

  SetVector<Function *> Functions(SCC.begin(), SCC.end());
  BumpPtrAllocator Allocator;
  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
                                /*CGSCC*/ &Functions, ModuleSlice);

  // TODO: Compute the module slice we are allowed to look at.
  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache);
  bool Changed = OMPOpt.run();
  (void)Changed;
  // NOTE(review): all analyses are reported preserved even when the IR was
  // changed — TODO confirm this is intentional and not missing a
  // PreservedAnalyses::none() on the Changed path.
  return PreservedAnalyses::all();
}
667 
668 namespace {
669 
/// Legacy pass manager wrapper around the OpenMP optimizations.
struct OpenMPOptLegacyPass : public CallGraphSCCPass {
  /// Call graph updater shared across runOnSCC invocations; finalized once in
  /// doFinalization.
  CallGraphUpdater CGUpdater;
  /// Cached answer of whether the module contains OpenMP runtime calls.
  OpenMPInModule OMPInModule;
  static char ID;

  OpenMPOptLegacyPass() : CallGraphSCCPass(ID) {
    initializeOpenMPOptLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  bool doInitialization(CallGraph &CG) override {
    // Disable the pass if there is no OpenMP (runtime call) in the module.
    containsOpenMP(CG.getModule(), OMPInModule);
    return false;
  }

  bool runOnSCC(CallGraphSCC &CGSCC) override {
    if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule))
      return false;
    if (DisableOpenMPOptimizations || skipSCC(CGSCC))
      return false;

    // Collect the defined functions of the SCC; declarations carry no IR to
    // optimize. They also form the module slice for now.
    SmallPtrSet<Function *, 16> ModuleSlice;
    SmallVector<Function *, 16> SCC;
    for (CallGraphNode *CGN : CGSCC)
      if (Function *Fn = CGN->getFunction())
        if (!Fn->isDeclaration()) {
          SCC.push_back(Fn);
          ModuleSlice.insert(Fn);
        }

    if (SCC.empty())
      return false;

    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
    CGUpdater.initialize(CG, CGSCC);

    // Maintain a map of functions to avoid rebuilding the ORE
    DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
    auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
      std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
      if (!ORE)
        ORE = std::make_unique<OptimizationRemarkEmitter>(F);
      return *ORE;
    };

    AnalysisGetter AG;
    SetVector<Function *> Functions(SCC.begin(), SCC.end());
    BumpPtrAllocator Allocator;
    OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
                                  Allocator,
                                  /*CGSCC*/ &Functions, ModuleSlice);

    // TODO: Compute the module slice we are allowed to look at.
    OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache);
    return OMPOpt.run();
  }

  bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
};
733 
734 } // end anonymous namespace
735 
/// Return (and cache in \p OMPInModule) whether \p M declares any known OpenMP
/// runtime function. Only names are checked, not signatures.
bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) {
  // Reuse the cached answer if we computed one before.
  if (OMPInModule.isKnown())
    return OMPInModule;

#define OMP_RTL(_Enum, _Name, ...)                                             \
  if (M.getFunction(_Name))                                                    \
    return OMPInModule = true;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  return OMPInModule = false;
}
746 
char OpenMPOptLegacyPass::ID = 0;

// Register the legacy pass; it depends on the call graph wrapper to drive the
// SCC traversal.
INITIALIZE_PASS_BEGIN(OpenMPOptLegacyPass, "openmpopt",
                      "OpenMP specific optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_END(OpenMPOptLegacyPass, "openmpopt",
                    "OpenMP specific optimizations", false, false)

/// Factory function used by the legacy pass manager.
Pass *llvm::createOpenMPOptLegacyPass() { return new OpenMPOptLegacyPass(); }
756