//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"

using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"

static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::ZeroOrMore,
    cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
                                    cl::Hidden);
static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
                                        cl::init(false), cl::Hidden);

static cl::opt<bool> HideMemoryTransferLatency(
    "openmp-hide-memory-transfer-latency",
    cl::desc("[WIP] Tries to hide the latency of host to device memory"
             " transfers"),
    cl::Hidden, cl::init(false));

STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");

#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif

/// Apply \p CB to all uses of \p F. If \p LookThroughConstantExprUses is
/// true, constant expression users are not given to \p CB but their uses are
/// traversed transitively.
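///
/// A minimal usage sketch (the lambda body here is hypothetical):
/// \code
///   foreachUse(F, [](Use &U) { /* inspect U.getUser() */ });
/// \endcode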
template <typename CBTy>
static void foreachUse(Function &F, CBTy CB,
                       bool LookThroughConstantExprUses = true) {
  SmallVector<Use *, 8> Worklist(make_pointer_range(F.uses()));

  for (unsigned Idx = 0; Idx < Worklist.size(); ++Idx) {
    Use &U = *Worklist[Idx];

    // Allow use in constant bitcasts and simply look through them.
    if (LookThroughConstantExprUses && isa<ConstantExpr>(U.getUser())) {
      for (Use &CEU : cast<ConstantExpr>(U.getUser())->uses())
        Worklist.push_back(&CEU);
      continue;
    }

    CB(U);
  }
}

/// Helper struct to store tracked ICV values at specific instructions.
struct ICVValue {
  Instruction *Inst;
  Value *TrackedValue;

  ICVValue(Instruction *I, Value *Val) : Inst(I), TrackedValue(Val) {}
};

namespace llvm {

// Provide DenseMapInfo for ICVValue
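// so that it can serve as a key in DenseMap/DenseSet-based containers,
// e.g., the SmallSetVector<ICVValue, 4> used by the ICV tracker below.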
template <> struct DenseMapInfo<ICVValue> {
  using InstInfo = DenseMapInfo<Instruction *>;
  using ValueInfo = DenseMapInfo<Value *>;

  static inline ICVValue getEmptyKey() {
    return ICVValue(InstInfo::getEmptyKey(), ValueInfo::getEmptyKey());
  }

  static inline ICVValue getTombstoneKey() {
    return ICVValue(InstInfo::getTombstoneKey(), ValueInfo::getTombstoneKey());
  }

  static unsigned getHashValue(const ICVValue &ICVVal) {
    return detail::combineHashValue(
        InstInfo::getHashValue(ICVVal.Inst),
        ValueInfo::getHashValue(ICVVal.TrackedValue));
  }

  static bool isEqual(const ICVValue &LHS, const ICVValue &RHS) {
    return InstInfo::isEqual(LHS.Inst, RHS.Inst) &&
           ValueInfo::isEqual(LHS.TrackedValue, RHS.TrackedValue);
  }
};

} // end namespace llvm

namespace {

struct AAICVTracker;

/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
                      SmallPtrSetImpl<Kernel> &Kernels)
      : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
        Kernels(Kernels) {
    initializeModuleSlice(CGSCC);

    OMPBuilder.initialize();
    initializeRuntimeFunctions();
    initializeInternalControlVars();
  }

  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by the InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL function corresponding to the override clause of this ICV.
    RuntimeFunction Clause;
  };

  /// Generic information that describes a runtime function.
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear UsesMap for runtime function.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
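    ///
    /// For example (an illustrative sketch; the callback shown keeps every
    /// use):
    /// \code
    ///   RFI.foreachUse(SCC, [](Use &U, Function &F) { return false; });
    /// \endcode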
    void foreachUse(SmallVectorImpl<Function *> &SCC,
                    function_ref<bool(Use &, Function &)> CB) {
      for (Function *F : SCC)
        foreachUse(CB, F);
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;

      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
  };

  /// Initialize the ModuleSlice member based on \p SCC. ModuleSlice contains
  /// (a subset of) all functions that we can look at during this SCC traversal.
  /// This includes functions (transitively) called from the SCC and the
  /// (transitive) callers of SCC functions. We also can look at a function if
  /// there is a "reference edge", i.e., if the function somehow uses (!=calls)
  /// a function in the SCC or a caller of a function in the SCC.
  void initializeModuleSlice(SetVector<Function *> &SCC) {
    ModuleSlice.insert(SCC.begin(), SCC.end());

    SmallPtrSet<Function *, 16> Seen;
    SmallVector<Function *, 16> Worklist(SCC.begin(), SCC.end());
    while (!Worklist.empty()) {
      Function *F = Worklist.pop_back_val();
      ModuleSlice.insert(F);

      for (Instruction &I : instructions(*F))
        if (auto *CB = dyn_cast<CallBase>(&I))
          if (Function *Callee = CB->getCalledFunction())
            if (Seen.insert(Callee).second)
              Worklist.push_back(Callee);
    }

    Seen.clear();
    Worklist.append(SCC.begin(), SCC.end());
    while (!Worklist.empty()) {
      Function *F = Worklist.pop_back_val();
      ModuleSlice.insert(F);

      // Traverse all transitive uses.
      foreachUse(*F, [&](Use &U) {
        if (auto *UsrI = dyn_cast<Instruction>(U.getUser()))
          if (Seen.insert(UsrI->getFunction()).second)
            Worklist.push_back(UsrI->getFunction());
      });
    }
  }

  /// The slice of the module we are allowed to look at.
  SmallPtrSet<Function *, 8> ModuleSlice;

  /// An OpenMP-IR-Builder instance
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;

  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.Name = _Name;                                                          \
    ICV.Kind = Enum;                                                           \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
      break;                                                                   \
    case ICV_LAST:                                                             \
      break;                                                                   \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    //       and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  /// Helper to collect all uses of the declaration in the UsesMap.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;
    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    // TODO: We directly convert uses into proper calls and unknown uses.
    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (ModuleSlice.count(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }

  /// Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx) {
      auto &RFI = RFIs[static_cast<RuntimeFunction>(Idx)];
      RFI.clearUsesMap();
      collectUses(RFI, /*CollectStats*/ false);
    }
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
  void initializeRuntimeFunctions() {
    Module &M = *((*ModuleSlice.begin())->getParent());

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  (void)VarName##Ty;                                                           \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }

  /// Collection of known kernels (\see Kernel) in the module.
  SmallPtrSetImpl<Kernel> &Kernels;
};

struct OpenMPOpt {

  using OptimizationRemarkGetter =
      function_ref<OptimizationRemarkEmitter &(Function *)>;

  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
  bool run() {
    if (SCC.empty())
      return false;

    bool Changed = false;

    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
                      << " functions in a slice with "
                      << OMPInfoCache.ModuleSlice.size() << " functions\n");

    if (PrintICVValues)
      printICVs();
    if (PrintOpenMPKernels)
      printKernels();

    Changed |= rewriteDeviceCodeStateMachine();

    Changed |= runAttributor();

    // Recollect uses, in case Attributor deleted any.
    OMPInfoCache.recollectUses();

    Changed |= deduplicateRuntimeCalls();
    Changed |= deleteParallelRegions();
    if (HideMemoryTransferLatency)
      Changed |= hideMemTransfersLatency();

    return Changed;
  }

  /// Print initial ICV values for testing.
  /// FIXME: This should be done from the Attributor once it is added.
  void printICVs() const {
    InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel};

    for (Function *F : OMPInfoCache.ModuleSlice) {
      for (auto ICV : ICVs) {
        auto ICVInfo = OMPInfoCache.ICVs[ICV];
        auto Remark = [&](OptimizationRemark OR) {
          return OR << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                    << " Value: "
                    << (ICVInfo.InitValue
                            ? ICVInfo.InitValue->getValue().toString(10, true)
                            : "IMPLEMENTATION_DEFINED");
        };

        emitRemarkOnFunction(F, "OpenMPICVTracker", Remark);
      }
    }
  }

  /// Print OpenMP GPU kernels for testing.
  void printKernels() const {
    for (Function *F : SCC) {
      if (!OMPInfoCache.Kernels.count(F))
        continue;

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "OpenMP GPU kernel "
                  << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
      };

      emitRemarkOnFunction(F, "OpenMPGPU", Remark);
    }
  }

  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given, it must be the callee; otherwise nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call. If \p RFI is given, it must be
  /// the callee; otherwise nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

private:
  /// Try to delete parallel regions if possible.
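  ///
  /// Schematically (a sketch, with argument types elided), a region is
  /// deletable when the outlined callback passed to __kmpc_fork_call only
  /// reads memory and always returns:
  /// \code
  ///   ; deletable if @body is `readonly` and `willreturn`
  ///   call void @__kmpc_fork_call(%struct.ident_t* @loc, i32 0, ... @body)
  /// \endcode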
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
                        << CI->getCaller()->getName() << "\n");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Parallel region in "
                  << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
                  << " deleted";
      };
      emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
                                     Remark);

      CGUpdater.removeCallSite(*CI);
      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(SCC, DeleteCallCB);

    return Changed;
  }

  /// Try to eliminate runtime calls by reusing existing ones.
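  ///
  /// For instance (sketch), repeated calls to a deduplicable runtime function
  /// within one function can be collapsed into a single call:
  /// \code
  ///   %a = call i32 @omp_get_num_threads()
  ///   ...
  ///   %b = call i32 @omp_get_num_threads() ; uses of %b replaced by %a
  /// \endcode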
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_thread_limit,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }

  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
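  ///
  /// Schematically (a sketch; argument lists elided):
  /// \code
  ///   ; before:
  ///   call void @__tgt_target_data_begin_mapper(...)
  ///   ; after:
  ///   %handle = call ... @__tgt_target_data_begin_mapper_issue(...)
  ///   ...        ; independent code the transfer can overlap with
  ///   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, ... %handle)
  /// \endcode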
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      // TODO: Check if the call can also be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };
    RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }

  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall
  /// can be moved. Returns nullptr if the movement is not possible, or not
  /// worth it.
  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
    // FIXME: This traverses only the BasicBlock where RuntimeCall is.
    //  Make it traverse the CFG.

    Instruction *CurrentI = &RuntimeCall;
    bool IsWorthIt = false;
    while ((CurrentI = CurrentI->getNextNode())) {

      // TODO: Once we detect the regions to be offloaded we should use the
      //  alias analysis manager to check if CurrentI may modify one of
      //  the offloaded regions.
      if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
        if (IsWorthIt)
          return CurrentI;

        return nullptr;
      }

      // FIXME: For now, moving the "wait" over anything without side effects
      //  is considered worth it.
      IsWorthIt = true;
    }

    // Return end of BasicBlock.
    return RuntimeCall.getParent()->getTerminator();
  }

  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
                               Instruction &WaitMovementPoint) {
    auto &IRBuilder = OMPInfoCache.OMPBuilder;
    // Add "issue" runtime call declaration:
    // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
    //   i8**, i8**, i64*, i64*)
    FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_issue);

    // Change the RuntimeCall call site to its asynchronous version.
    SmallVector<Value *, 8> Args;
    for (auto &Arg : RuntimeCall.args())
      Args.push_back(Arg.get());

    CallInst *IssueCallsite =
        CallInst::Create(IssueDecl, Args, "handle", &RuntimeCall);
    RuntimeCall.eraseFromParent();

    // Add "wait" runtime call declaration:
    // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
    FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_wait);

    // Add call site to WaitDecl.
    Value *WaitParams[2] = {
        IssueCallsite->getArgOperand(0), // device_id.
        IssueCallsite                    // returned handle.
    };
    CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);

    return true;
  }

  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                    bool GlobalOnly, bool &SingleChoice) {
    if (CurrentIdent == NextIdent)
      return CurrentIdent;

    // TODO: Figure out how to actually combine multiple debug locations. For
    //       now we just keep an existing one if there is a single choice.
    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
      SingleChoice = !CurrentIdent;
      return NextIdent;
    }
    return nullptr;
  }

  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
  Value *
  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
                                 Function &F, bool GlobalOnly) {
    bool SingleChoice = true;
    Value *Ident = nullptr;
    auto CombineIdentStruct = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || &F != &Caller)
        return false;
      Ident = combinedIdentStruct(Ident, CI->getArgOperand(0), GlobalOnly,
                                  SingleChoice);
      return false;
    };
    RFI.foreachUse(SCC, CombineIdentStruct);

    if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module; this is
      // unfortunate, but we work around it for now.
      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
        OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
            &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
      // TODO: Use the debug locations of the calls instead.
      Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
    }
    return Ident;
  }

  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
  /// \p ReplVal if given.
  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)
      return false;

    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n"));

    assert((!ReplVal || (isa<Argument>(ReplVal) &&
                         cast<Argument>(ReplVal)->getParent() == &F)) &&
           "Unexpected replacement value!");

    // TODO: Use dominance to find a good position instead.
    auto CanBeMoved = [this](CallBase &CB) {
      unsigned NumArgs = CB.getNumArgOperands();
      if (NumArgs == 0)
        return true;
      if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
        return false;
      for (unsigned u = 1; u < NumArgs; ++u)
        if (isa<Instruction>(CB.getArgOperand(u)))
          return false;
      return true;
    };

    if (!ReplVal) {
      for (Use *U : *UV)
        if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
          if (!CanBeMoved(*CI))
            continue;

          auto Remark = [&](OptimizationRemark OR) {
            auto NewLoc = &*F.getEntryBlock().getFirstInsertionPt();
            return OR << "OpenMP runtime call "
                      << ore::NV("OpenMPOptRuntime", RFI.Name) << " moved to "
                      << ore::NV("OpenMPRuntimeMoves", NewLoc->getDebugLoc());
          };
          emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeCodeMotion", Remark);

          CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
          ReplVal = CI;
          break;
        }
      if (!ReplVal)
        return false;
    }

    // If we use a call as a replacement value we need to make sure the ident is
    // valid at the new location. For now we just pick a global one, either
    // existing and used by one of the calls, or created from scratch.
    if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
      if (CI->getNumArgOperands() > 0 &&
          CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
        Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
                                                      /* GlobalOnly */ true);
        CI->setArgOperand(0, Ident);
      }
    }

    bool Changed = false;
    auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || CI == ReplVal || &F != &Caller)
        return false;
      assert(CI->getCaller() == &F && "Unexpected call!");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "OpenMP runtime call "
                  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
      };
      emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeDeduplicated", Remark);

      CGUpdater.removeCallSite(*CI);
      CI->replaceAllUsesWith(ReplVal);
      CI->eraseFromParent();
      ++NumOpenMPRuntimeCallsDeduplicated;
      Changed = true;
      return true;
    };
    RFI.foreachUse(SCC, ReplaceAndDeleteCB);

    return Changed;
  }

  /// Collect arguments that represent the global thread id in \p GTIdArgs.
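  ///
  /// For example (a sketch; @helper and @loc are hypothetical names), given
  /// \code
  ///   %gtid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   call void @helper(i32 %gtid)
  /// \endcode
  /// the first argument of a local-linkage @helper is collected as a GTId,
  /// provided every call site of @helper passes a GTId.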
  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
    // TODO: Below we basically perform a fixpoint iteration with a pessimistic
    //       initialization. We could define an AbstractAttribute instead and
    //       run the Attributor here once it can be run as an SCC pass.

    // Helper to check the argument \p ArgNo at all call sites of \p F for
    // a GTId.
    auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
      if (!F.hasLocalLinkage())
        return false;
      for (Use &U : F.uses()) {
        if (CallInst *CI = getCallIfRegularCall(U)) {
          Value *ArgOp = CI->getArgOperand(ArgNo);
          if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
              getCallIfRegularCall(
                  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
            continue;
        }
        return false;
      }
      return true;
    };

    // Helper to identify uses of a GTId as GTId arguments.
    auto AddUserArgs = [&](Value &GTId) {
      for (Use &U : GTId.uses())
        if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
          if (CI->isArgOperand(&U))
            if (Function *Callee = CI->getCalledFunction())
              if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
                GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
    };

    // The argument users of __kmpc_global_thread_num calls are GTIds.
    OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];

    GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
      if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
        AddUserArgs(*CI);
      return false;
    });

    // Transitively search for more arguments by looking at the users of the
    // ones we know already. During the search the GTIdArgs vector is extended
    // so we cannot cache the size nor can we use a range based for.
    for (unsigned u = 0; u < GTIdArgs.size(); ++u)
      AddUserArgs(*GTIdArgs[u]);
  }

  /// Kernel (=GPU) optimizations and utility functions
  ///
  ///{{

  /// Check if \p F is a kernel, hence entry point for target offloading.
  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }

  /// Cache to remember the unique kernel for a function.
  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;

  /// Find the unique kernel that will execute \p F, if any.
  Kernel getUniqueKernelFor(Function &F);

  /// Find the unique kernel that will execute \p I, if any.
  Kernel getUniqueKernelFor(Instruction &I) {
    return getUniqueKernelFor(*I.getFunction());
  }

  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// cases where we can avoid taking the address of a function.
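  ///
  /// Schematically (a sketch; @par_fn and %work_fn are hypothetical names),
  /// the state machine then compares against a private global "ID" instead of
  /// the parallel region's function address:
  /// \code
  ///   ; before: icmp eq i8* %work_fn, bitcast (void ()* @par_fn to i8*)
  ///   ; after:  icmp eq i8* %work_fn, @par_fn.ID
  /// \endcode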
  bool rewriteDeviceCodeStateMachine();

  ///
  ///}}

  /// Emit a remark generically
  ///
  /// This template function can be used to generically emit a remark. The
  /// RemarkKind should be one of the following:
  ///   - OptimizationRemark to indicate a successful optimization attempt
  ///   - OptimizationRemarkMissed to report a failed optimization attempt
  ///   - OptimizationRemarkAnalysis to provide additional information about an
  ///     optimization attempt
  ///
  /// The remark is built using a callback function provided by the caller that
  /// takes a RemarkKind as input and returns a RemarkKind.
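  ///
  /// A typical invocation mirrors the uses in this file, e.g.:
  /// \code
  ///   auto Remark = [&](OptimizationRemark OR) {
  ///     return OR << "OpenMP runtime call deduplicated";
  ///   };
  ///   emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeDeduplicated", Remark);
  /// \endcode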
  template <typename RemarkKind,
            typename RemarkCallBack = function_ref<RemarkKind(RemarkKind &&)>>
  void emitRemark(Instruction *Inst, StringRef RemarkName,
                  RemarkCallBack &&RemarkCB) const {
    Function *F = Inst->getParent()->getParent();
    auto &ORE = OREGetter(F);

    ORE.emit(
        [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, Inst)); });
  }

  /// Emit a remark on a function. Since only OptimizationRemark supports
  /// this, it can't be made generic.
  void
  emitRemarkOnFunction(Function *F, StringRef RemarkName,
                       function_ref<OptimizationRemark(OptimizationRemark &&)>
                           &&RemarkCB) const {
    auto &ORE = OREGetter(F);

    ORE.emit([&]() {
      return RemarkCB(OptimizationRemark(DEBUG_TYPE, RemarkName, F));
    });
  }

  /// The underlying module.
  Module &M;

  /// The SCC we are operating on.
  SmallVectorImpl<Function *> &SCC;

  /// Callback to update the call graph; the first argument is a removed call,
  /// the second an optional replacement call.
  CallGraphUpdater &CGUpdater;

  /// Callback to get an OptimizationRemarkEmitter from a Function *.
  OptimizationRemarkGetter OREGetter;

  /// OpenMP-specific information cache. Also used for Attributor runs.
  OMPInformationCache &OMPInfoCache;

  /// Attributor instance.
  Attributor &A;

  /// Helper function to run Attributor on SCC.
  bool runAttributor() {
    if (SCC.empty())
      return false;

    registerAAs();

    ChangeStatus Changed = A.run();

    LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
                      << " functions, result: " << Changed << ".\n");

    return Changed == ChangeStatus::CHANGED;
  }

  /// Populate the Attributor with abstract attribute opportunities in the
  /// function.
  void registerAAs() {
    for (Function *F : SCC) {
      if (F->isDeclaration())
        continue;

      A.getOrCreateAAFor<AAICVTracker>(IRPosition::function(*F));
    }
  }
};

Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
  if (!OMPInfoCache.ModuleSlice.count(&F))
    return nullptr;

  // Use a scope to keep the lifetime of the CachedKernel short.
  {
    Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
    if (CachedKernel)
      return *CachedKernel;

    // TODO: We should use an AA to create an (optimistic and callback
    //       call-aware) call graph. For now we stick to simple patterns that
    //       are less powerful, basically the worst fixpoint.
    if (isKernel(F)) {
      CachedKernel = Kernel(&F);
      return *CachedKernel;
    }

    CachedKernel = nullptr;
    if (!F.hasLocalLinkage())
      return nullptr;
  }

  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
      // Allow use in equality comparisons.
      if (Cmp->isEquality())
        return getUniqueKernelFor(*Cmp);
      return nullptr;
    }
    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
      // Allow direct calls.
      if (CB->isCallee(&U))
        return getUniqueKernelFor(*CB);
      // Allow the use in __kmpc_kernel_prepare_parallel calls.
      if (Function *Callee = CB->getCalledFunction())
        if (Callee->getName() == "__kmpc_kernel_prepare_parallel")
          return getUniqueKernelFor(*CB);
      return nullptr;
    }
    // Disallow every other use.
    return nullptr;
  };

  // TODO: In the future we want to track more than just a unique kernel.
  SmallPtrSet<Kernel, 2> PotentialKernels;
  foreachUse(F, [&](const Use &U) {
    PotentialKernels.insert(GetUniqueKernelForUse(U));
  });

  Kernel K = nullptr;
  if (PotentialKernels.size() == 1)
    K = *PotentialKernels.begin();

  // Cache the result.
  UniqueKernelMap[&F] = K;

  return K;
}

bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
  OMPInformationCache::RuntimeFunctionInfo &KernelPrepareParallelRFI =
      OMPInfoCache.RFIs[OMPRTL___kmpc_kernel_prepare_parallel];

  bool Changed = false;
  if (!KernelPrepareParallelRFI)
    return Changed;

  for (Function *F : SCC) {

    // Check if the function is used in a __kmpc_kernel_prepare_parallel call
    // at all.
    bool UnknownUse = false;
    bool KernelPrepareUse = false;
    unsigned NumDirectCalls = 0;

    SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
    foreachUse(*F, [&](Use &U) {
      if (auto *CB = dyn_cast<CallBase>(U.getUser()))
        if (CB->isCallee(&U)) {
          ++NumDirectCalls;
          return;
        }

      if (isa<ICmpInst>(U.getUser())) {
        ToBeReplacedStateMachineUses.push_back(&U);
        return;
      }
      if (!KernelPrepareUse && OpenMPOpt::getCallIfRegularCall(
                                   *U.getUser(), &KernelPrepareParallelRFI)) {
        KernelPrepareUse = true;
        ToBeReplacedStateMachineUses.push_back(&U);
        return;
      }
      UnknownUse = true;
    });

    // Do not emit a remark if we haven't seen a __kmpc_kernel_prepare_parallel
    // use.
    if (!KernelPrepareUse)
      continue;

    {
      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Found a parallel region that is called in a target "
                     "region but not part of a combined target construct nor "
                     "nested inside a target construct without intermediate "
                     "code. This can lead to excessive register usage for "
                     "unrelated target regions in the same translation unit "
                     "due to spurious call edges assumed by ptxas.";
      };
      emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD", Remark);
    }

    // If this ever hits, we should investigate.
    // TODO: Checking the number of uses is not a necessary restriction and
    // should be lifted.
    if (UnknownUse || NumDirectCalls != 1 ||
        ToBeReplacedStateMachineUses.size() != 2) {
      {
        auto Remark = [&](OptimizationRemark OR) {
          return OR << "Parallel region is used in "
                    << (UnknownUse ? "unknown" : "unexpected")
                    << " ways; will not attempt to rewrite the state machine.";
        };
        emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD", Remark);
      }
      continue;
    }

    // Even if we have __kmpc_kernel_prepare_parallel calls, we (for now) give
    // up if the function is not called from a unique kernel.
    Kernel K = getUniqueKernelFor(*F);
    if (!K) {
      {
        auto Remark = [&](OptimizationRemark OR) {
          return OR << "Parallel region is not known to be called from a "
1200                        "unique single target region, maybe the surrounding "
1201                        "function has external linkage?; will not attempt to "
1202                        "rewrite the state machine use.";
1203         };
1204         emitRemarkOnFunction(F, "OpenMPParallelRegionInMultipleKernesl",
1205                              Remark);
      }
      continue;
    }

    // We now know F is a parallel body function called only from the kernel K.
    // We also identified the state machine uses in which we replace the
    // function pointer by a new global symbol for identification purposes. This
    // ensures only direct calls to the function are left.

    {
      auto RemarkParallelRegion = [&](OptimizationRemark OR) {
        return OR << "Specialize parallel region that is only reached from a "
                     "single target region to avoid spurious call edges and "
                     "excessive register usage in other target regions. "
                     "(parallel region ID: "
                  << ore::NV("OpenMPParallelRegion", F->getName())
                  << ", kernel ID: "
                  << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
      };
      emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD",
                           RemarkParallelRegion);
      auto RemarkKernel = [&](OptimizationRemark OR) {
        return OR << "Target region containing the parallel region that is "
                     "specialized. (parallel region ID: "
                  << ore::NV("OpenMPParallelRegion", F->getName())
                  << ", kernel ID: "
                  << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
      };
      emitRemarkOnFunction(K, "OpenMPParallelRegionInNonSPMD", RemarkKernel);
    }

    Module &M = *F->getParent();
    Type *Int8Ty = Type::getInt8Ty(M.getContext());

    auto *ID = new GlobalVariable(
        M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
        UndefValue::get(Int8Ty), F->getName() + ".ID");

    for (Use *U : ToBeReplacedStateMachineUses)
      U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));

    ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;

    Changed = true;
  }

  return Changed;
}

/// Abstract Attribute for tracking ICV values.
struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;
  AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Returns true if value is assumed to be tracked.
  bool isAssumedTracked() const { return getAssumed(); }

  /// Returns true if value is known to be tracked.
  bool isKnownTracked() const { return getKnown(); }

  /// Create an abstract attribute view for the position \p IRP.
  static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);

  /// Return the value with which \p I can be replaced for specific \p ICV.
  virtual Value *getReplacementValue(InternalControlVar ICV,
                                     const Instruction *I, Attributor &A) = 0;

  /// See AbstractAttribute::getName()
  const std::string getName() const override { return "AAICVTracker"; }

  /// See AbstractAttribute::getIdAddr()
  const char *getIdAddr() const override { return &ID; }

  /// This function should return true if the type of the \p AA is AAICVTracker
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  static const char ID;
};

struct AAICVTrackerFunction : public AAICVTracker {
  AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
      : AAICVTracker(IRP, A) {}

  // FIXME: come up with a better string.
  const std::string getAsStr() const override { return "ICVTracker"; }

  // FIXME: come up with some stats.
  void trackStatistics() const override {}

  /// TODO: decide whether to deduplicate here, or use current
  /// deduplicateRuntimeCalls function.
  ChangeStatus manifest(Attributor &A) override {
    ChangeStatus Changed = ChangeStatus::UNCHANGED;

    for (InternalControlVar &ICV : TrackableICVs)
      if (deduplicateICVGetters(ICV, A))
        Changed = ChangeStatus::CHANGED;

    return Changed;
  }

  bool deduplicateICVGetters(InternalControlVar &ICV, Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &ICVInfo = OMPInfoCache.ICVs[ICV];
    auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];

    bool Changed = false;

    auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
      CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
      Instruction *UserI = cast<Instruction>(U.getUser());
      Value *ReplVal = getReplacementValue(ICV, UserI, A);

      if (!ReplVal || !CI)
        return false;

      A.removeCallSite(CI);
      CI->replaceAllUsesWith(ReplVal);
      CI->eraseFromParent();
      Changed = true;
      return true;
    };

    GetterRFI.foreachUse(ReplaceAndDeleteCB, getAnchorScope());
    return Changed;
  }

  // Map of ICVs to their values at specific program points.
  EnumeratedArray<SmallSetVector<ICVValue, 4>, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVValuesMap;

  // Currently only nthreads is being tracked.
  // This array will only grow over time.
  InternalControlVar TrackableICVs[1] = {ICV_nthreads};

  ChangeStatus updateImpl(Attributor &A) override {
    ChangeStatus HasChanged = ChangeStatus::UNCHANGED;

    Function *F = getAnchorScope();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    for (InternalControlVar ICV : TrackableICVs) {
      auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];

      auto TrackValues = [&](Use &U, Function &) {
        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
        if (!CI)
          return false;

        // FIXME: handle setters with more than one argument.
        // Track the new value.
        if (ICVValuesMap[ICV].insert(ICVValue(CI, CI->getArgOperand(0))))
          HasChanged = ChangeStatus::CHANGED;

        return false;
      };

      SetterRFI.foreachUse(TrackValues, F);
    }

    return HasChanged;
  }

  /// Return the value with which \p I can be replaced for specific \p ICV.
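  /// For example (a sketch): after a call to omp_set_num_threads(i32 4), a
  /// later omp_get_max_threads() call in the same basic block with no
  /// intervening unknown calls could be replaced by the constant 4.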
1374   Value *getReplacementValue(InternalControlVar ICV, const Instruction *I,
1375                              Attributor &A) override {
1376     const BasicBlock *CurrBB = I->getParent();
1377 
1378     auto &ValuesSet = ICVValuesMap[ICV];
1379     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1380     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
1381 
1382     for (const auto &ICVVal : ValuesSet) {
1383       if (CurrBB == ICVVal.Inst->getParent()) {
1384         if (!ICVVal.Inst->comesBefore(I))
1385           continue;
1386 
1387         // both instructions are in the same BB and at \p I we know the ICV
1388         // value.
1389         while (I != ICVVal.Inst) {
          // We don't yet know if a call might update an ICV.
          // TODO: Check the call site AA for the value.
1392           if (const auto *CB = dyn_cast<CallBase>(I))
1393             if (CB->getCalledFunction() != GetterRFI.Declaration)
1394               return nullptr;
1395 
1396           I = I->getPrevNode();
1397         }
1398 
1399         // No call in between, return the value.
1400         return ICVVal.TrackedValue;
1401       }
1402     }
1403 
1404     // No value was tracked.
1405     return nullptr;
1406   }
1407 };
1408 } // namespace
1409 
1410 const char AAICVTracker::ID = 0;
1411 
1412 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
1413                                               Attributor &A) {
1414   AAICVTracker *AA = nullptr;
1415   switch (IRP.getPositionKind()) {
1416   case IRPosition::IRP_INVALID:
1417   case IRPosition::IRP_FLOAT:
1418   case IRPosition::IRP_ARGUMENT:
1419   case IRPosition::IRP_RETURNED:
1420   case IRPosition::IRP_CALL_SITE_RETURNED:
1421   case IRPosition::IRP_CALL_SITE_ARGUMENT:
1422   case IRPosition::IRP_CALL_SITE:
1423     llvm_unreachable("ICVTracker can only be created for function position!");
1424   case IRPosition::IRP_FUNCTION:
1425     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
1426     break;
1427   }
1428 
1429   return *AA;
1430 }
1431 
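// AAICVTracker instances are usually requested through the Attributor rather
// than created directly; a sketch, assuming a Function F that is part of the
// code region registered with Attributor A:
//
//   const auto &ICVAA =
//       A.getOrCreateAAFor<AAICVTracker>(IRPosition::function(F));
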
1432 PreservedAnalyses OpenMPOptPass::run(LazyCallGraph::SCC &C,
1433                                      CGSCCAnalysisManager &AM,
1434                                      LazyCallGraph &CG, CGSCCUpdateResult &UR) {
1435   if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule))
1436     return PreservedAnalyses::all();
1437 
1438   if (DisableOpenMPOptimizations)
1439     return PreservedAnalyses::all();
1440 
1441   SmallVector<Function *, 16> SCC;
  // If there are kernels in the module, we have to run on all SCCs.
1443   bool SCCIsInteresting = !OMPInModule.getKernels().empty();
1444   for (LazyCallGraph::Node &N : C) {
1445     Function *Fn = &N.getFunction();
1446     SCC.push_back(Fn);
1447 
1448     // Do we already know that the SCC contains kernels,
1449     // or that OpenMP functions are called from this SCC?
1450     if (SCCIsInteresting)
1451       continue;
1452     // If not, let's check that.
1453     SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn);
1454   }
1455 
1456   if (!SCCIsInteresting || SCC.empty())
1457     return PreservedAnalyses::all();
1458 
1459   FunctionAnalysisManager &FAM =
1460       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
1461 
1462   AnalysisGetter AG(FAM);
1463 
1464   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
1465     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
1466   };
1467 
1468   CallGraphUpdater CGUpdater;
1469   CGUpdater.initialize(CG, C, AM, UR);
1470 
1471   SetVector<Function *> Functions(SCC.begin(), SCC.end());
1472   BumpPtrAllocator Allocator;
1473   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
1474                                 /*CGSCC*/ Functions, OMPInModule.getKernels());
1475 
1476   Attributor A(Functions, InfoCache, CGUpdater);
1477 
1478   // TODO: Compute the module slice we are allowed to look at.
1479   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
1480   bool Changed = OMPOpt.run();
1481   if (Changed)
1482     return PreservedAnalyses::none();
1483 
1484   return PreservedAnalyses::all();
1485 }
1486 
1487 namespace {
1488 
1489 struct OpenMPOptLegacyPass : public CallGraphSCCPass {
1490   CallGraphUpdater CGUpdater;
1491   OpenMPInModule OMPInModule;
1492   static char ID;
1493 
1494   OpenMPOptLegacyPass() : CallGraphSCCPass(ID) {
1495     initializeOpenMPOptLegacyPassPass(*PassRegistry::getPassRegistry());
1496   }
1497 
1498   void getAnalysisUsage(AnalysisUsage &AU) const override {
1499     CallGraphSCCPass::getAnalysisUsage(AU);
1500   }
1501 
1502   bool doInitialization(CallGraph &CG) override {
1503     // Disable the pass if there is no OpenMP (runtime call) in the module.
1504     containsOpenMP(CG.getModule(), OMPInModule);
1505     return false;
1506   }
1507 
1508   bool runOnSCC(CallGraphSCC &CGSCC) override {
1509     if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule))
1510       return false;
1511     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
1512       return false;
1513 
1514     SmallVector<Function *, 16> SCC;
    // If there are kernels in the module, we have to run on all SCCs.
1516     bool SCCIsInteresting = !OMPInModule.getKernels().empty();
1517     for (CallGraphNode *CGN : CGSCC) {
1518       Function *Fn = CGN->getFunction();
1519       if (!Fn || Fn->isDeclaration())
1520         continue;
1521       SCC.push_back(Fn);
1522 
1523       // Do we already know that the SCC contains kernels,
1524       // or that OpenMP functions are called from this SCC?
1525       if (SCCIsInteresting)
1526         continue;
1527       // If not, let's check that.
1528       SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn);
1529     }
1530 
1531     if (!SCCIsInteresting || SCC.empty())
1532       return false;
1533 
1534     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
1535     CGUpdater.initialize(CG, CGSCC);
1536 
    // Maintain a map of functions to avoid rebuilding the ORE for each one.
1538     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
1539     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
1540       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
1541       if (!ORE)
1542         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
1543       return *ORE;
1544     };
1545 
1546     AnalysisGetter AG;
1547     SetVector<Function *> Functions(SCC.begin(), SCC.end());
1548     BumpPtrAllocator Allocator;
1549     OMPInformationCache InfoCache(
1550         *(Functions.back()->getParent()), AG, Allocator,
1551         /*CGSCC*/ Functions, OMPInModule.getKernels());
1552 
1553     Attributor A(Functions, InfoCache, CGUpdater);
1554 
1555     // TODO: Compute the module slice we are allowed to look at.
1556     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
1557     return OMPOpt.run();
1558   }
1559 
1560   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
1561 };
1562 
1563 } // end anonymous namespace
1564 
1565 void OpenMPInModule::identifyKernels(Module &M) {
1566 
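  // Kernels are communicated to us via "nvvm.annotations" module metadata of
  // the following form (illustrative IR; real kernel names are mangled):
  //
  //   !nvvm.annotations = !{!0}
  //   !0 = !{void ()* @kernel_fn, !"kernel", i32 1}
  //
  // Note that we only look the metadata up; modules without it are left
  // untouched.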
  NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
  if (!MD)
    return;
1570 
1571   for (auto *Op : MD->operands()) {
1572     if (Op->getNumOperands() < 2)
1573       continue;
1574     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
1575     if (!KindID || KindID->getString() != "kernel")
1576       continue;
1577 
1578     Function *KernelFn =
1579         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
1580     if (!KernelFn)
1581       continue;
1582 
1583     ++NumOpenMPTargetRegionKernels;
1584 
1585     Kernels.insert(KernelFn);
1586   }
1587 }
1588 
1589 bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) {
1590   if (OMPInModule.isKnown())
1591     return OMPInModule;
1592 
1593   auto RecordFunctionsContainingUsesOf = [&](Function *F) {
1594     for (User *U : F->users())
1595       if (auto *I = dyn_cast<Instruction>(U))
1596         OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction());
1597   };
1598 
  // MSVC doesn't like long if-else chains for some reason and instead just
  // issues an error. Work around it.
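  // For every runtime function listed in OMPKinds.def the macro below expands
  // to a declaration lookup; e.g., for omp_get_thread_num the (illustrative)
  // expansion is:
  //
  //   if (Function *F = M.getFunction("omp_get_thread_num")) {
  //     RecordFunctionsContainingUsesOf(F);
  //     OMPInModule = true;
  //   }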
1601   do {
1602 #define OMP_RTL(_Enum, _Name, ...)                                             \
1603   if (Function *F = M.getFunction(_Name)) {                                    \
1604     RecordFunctionsContainingUsesOf(F);                                        \
1605     OMPInModule = true;                                                        \
1606   }
1607 #include "llvm/Frontend/OpenMP/OMPKinds.def"
1608   } while (false);
1609 
1610   // Identify kernels once. TODO: We should split the OMPInformationCache into a
1611   // module and an SCC part. The kernel information, among other things, could
1612   // go into the module part.
1613   if (OMPInModule.isKnown() && OMPInModule) {
1614     OMPInModule.identifyKernels(M);
1615     return true;
1616   }
1617 
1618   return OMPInModule = false;
1619 }
1620 
1621 char OpenMPOptLegacyPass::ID = 0;
1622 
1623 INITIALIZE_PASS_BEGIN(OpenMPOptLegacyPass, "openmpopt",
1624                       "OpenMP specific optimizations", false, false)
1625 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
1626 INITIALIZE_PASS_END(OpenMPOptLegacyPass, "openmpopt",
1627                     "OpenMP specific optimizations", false, false)
1628 
1629 Pass *llvm::createOpenMPOptLegacyPass() { return new OpenMPOptLegacyPass(); }
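
// A quick way to exercise the legacy pass in isolation (a sketch, assuming a
// typical LLVM build with the legacy pass manager; OpenMPOptPass above is the
// new pass manager counterpart):
//
//   opt -openmpopt -S input.ll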
1630