1 //===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "OffloadWrapper.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/Triple.h"
12 #include "llvm/IR/Constants.h"
13 #include "llvm/IR/GlobalVariable.h"
14 #include "llvm/IR/IRBuilder.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/Support/Error.h"
18 #include "llvm/Transforms/Utils/ModuleUtils.h"
19 
20 using namespace llvm;
21 
22 namespace {
/// Magic number written as the first field of the CUDA fatbinary wrapper
/// object (see fatbin_wrapper below).
constexpr unsigned CudaFatMagic{0x466243b1};
/// Magic number written as the first field of the HIP fatbinary wrapper
/// object; the bytes spell "HIPF" in ASCII.
constexpr unsigned HIPFatMagic{0x48495046};

/// Kind of an offloading entry, stored in the `flags` field of
/// __tgt_offload_entry. Copied from clang/CGCudaRuntime.h.
enum OffloadEntryKindFlag : uint32_t {
  /// A global entry: a kernel when the entry's size field is zero, a variable
  /// otherwise.
  OffloadGlobalEntry = 0x0,
  /// A managed global variable.
  OffloadGlobalManagedEntry = 0x1,
  /// A surface variable.
  OffloadGlobalSurfaceEntry = 0x2,
  /// A texture variable.
  OffloadGlobalTextureEntry = 0x3,
};
39 
40 IntegerType *getSizeTTy(Module &M) {
41   LLVMContext &C = M.getContext();
42   switch (M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))) {
43   case 4u:
44     return Type::getInt32Ty(C);
45   case 8u:
46     return Type::getInt64Ty(C);
47   }
48   llvm_unreachable("unsupported pointer type size");
49 }
50 
51 // struct __tgt_offload_entry {
52 //   void *addr;
53 //   char *name;
54 //   size_t size;
55 //   int32_t flags;
56 //   int32_t reserved;
57 // };
58 StructType *getEntryTy(Module &M) {
59   LLVMContext &C = M.getContext();
60   StructType *EntryTy = StructType::getTypeByName(C, "__tgt_offload_entry");
61   if (!EntryTy)
62     EntryTy = StructType::create("__tgt_offload_entry", Type::getInt8PtrTy(C),
63                                  Type::getInt8PtrTy(C), getSizeTTy(M),
64                                  Type::getInt32Ty(C), Type::getInt32Ty(C));
65   return EntryTy;
66 }
67 
68 PointerType *getEntryPtrTy(Module &M) {
69   return PointerType::getUnqual(getEntryTy(M));
70 }
71 
72 // struct __tgt_device_image {
73 //   void *ImageStart;
74 //   void *ImageEnd;
75 //   __tgt_offload_entry *EntriesBegin;
76 //   __tgt_offload_entry *EntriesEnd;
77 // };
78 StructType *getDeviceImageTy(Module &M) {
79   LLVMContext &C = M.getContext();
80   StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image");
81   if (!ImageTy)
82     ImageTy = StructType::create("__tgt_device_image", Type::getInt8PtrTy(C),
83                                  Type::getInt8PtrTy(C), getEntryPtrTy(M),
84                                  getEntryPtrTy(M));
85   return ImageTy;
86 }
87 
88 PointerType *getDeviceImagePtrTy(Module &M) {
89   return PointerType::getUnqual(getDeviceImageTy(M));
90 }
91 
92 // struct __tgt_bin_desc {
93 //   int32_t NumDeviceImages;
94 //   __tgt_device_image *DeviceImages;
95 //   __tgt_offload_entry *HostEntriesBegin;
96 //   __tgt_offload_entry *HostEntriesEnd;
97 // };
98 StructType *getBinDescTy(Module &M) {
99   LLVMContext &C = M.getContext();
100   StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc");
101   if (!DescTy)
102     DescTy = StructType::create("__tgt_bin_desc", Type::getInt32Ty(C),
103                                 getDeviceImagePtrTy(M), getEntryPtrTy(M),
104                                 getEntryPtrTy(M));
105   return DescTy;
106 }
107 
108 PointerType *getBinDescPtrTy(Module &M) {
109   return PointerType::getUnqual(getBinDescTy(M));
110 }
111 
/// Creates binary descriptor for the given device images. Binary descriptor
/// is an object that is passed to the offloading runtime at program startup
/// and it describes all device images available in the executable or shared
/// library. It is defined as follows
///
/// __attribute__((visibility("hidden")))
/// extern const __tgt_offload_entry __start_omp_offloading_entries;
/// __attribute__((visibility("hidden")))
/// extern const __tgt_offload_entry __stop_omp_offloading_entries;
///
/// static const char Image0[] = { <Bufs.front() contents> };
///  ...
/// static const char ImageN[] = { <Bufs.back() contents> };
///
/// static const __tgt_device_image Images[] = {
///   {
///     Image0,                            /*ImageStart*/
///     Image0 + sizeof(Image0),           /*ImageEnd*/
///     &__start_omp_offloading_entries,   /*EntriesBegin*/
///     &__stop_omp_offloading_entries     /*EntriesEnd*/
///   },
///   ...
///   {
///     ImageN,                            /*ImageStart*/
///     ImageN + sizeof(ImageN),           /*ImageEnd*/
///     &__start_omp_offloading_entries,   /*EntriesBegin*/
///     &__stop_omp_offloading_entries     /*EntriesEnd*/
///   }
/// };
///
/// static const __tgt_bin_desc BinDesc = {
///   sizeof(Images) / sizeof(Images[0]),  /*NumDeviceImages*/
///   Images,                              /*DeviceImages*/
///   &__start_omp_offloading_entries,     /*HostEntriesBegin*/
///   &__stop_omp_offloading_entries       /*HostEntriesEnd*/
/// };
///
/// The start/stop symbols are emitted as external declarations; their
/// addresses, not their contents, are stored in the tables above. A hidden,
/// zero-length `__dummy.omp_offloading.entry` array is also placed in the
/// `omp_offloading_entries` section so that section always exists in the link.
///
/// \param M     module that receives all of the globals above.
/// \param Bufs  device image contents, one element per image. Each element
///              becomes its own internal constant global.
/// \returns the internal constant global that holds BinDesc.
GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) {
  LLVMContext &C = M.getContext();
  // Create external begin/end symbols for the offload entries table.
  auto *EntriesB = new GlobalVariable(
      M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
      /*Initializer*/ nullptr, "__start_omp_offloading_entries");
  EntriesB->setVisibility(GlobalValue::HiddenVisibility);
  auto *EntriesE = new GlobalVariable(
      M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
      /*Initializer*/ nullptr, "__stop_omp_offloading_entries");
  EntriesE->setVisibility(GlobalValue::HiddenVisibility);

  // We assume that external begin/end symbols that we have created above will
  // be defined by the linker. But linker will do that only if linker inputs
  // have section with "omp_offloading_entries" name which is not guaranteed.
  // So, we just create dummy zero sized object in the offload entries section
  // to force linker to define those symbols.
  auto *DummyInit =
      ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
  auto *DummyEntry = new GlobalVariable(
      M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
      "__dummy.omp_offloading.entry");
  DummyEntry->setSection("omp_offloading_entries");
  DummyEntry->setVisibility(GlobalValue::HiddenVisibility);

  // GEP indices {0, 0} address the first element of an array global.
  auto *Zero = ConstantInt::get(getSizeTTy(M), 0u);
  Constant *ZeroZero[] = {Zero, Zero};

  // Create initializer for the images array.
  SmallVector<Constant *, 4u> ImagesInits;
  ImagesInits.reserve(Bufs.size());
  for (ArrayRef<char> Buf : Bufs) {
    auto *Data = ConstantDataArray::get(C, Buf);
    auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
                                     GlobalVariable::InternalLinkage, Data,
                                     ".omp_offloading.device_image");
    Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);

    // GEP indices {0, size} form the one-past-the-end address of the image.
    auto *Size = ConstantInt::get(getSizeTTy(M), Buf.size());
    Constant *ZeroSize[] = {Zero, Size};

    auto *ImageB =
        ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroZero);
    auto *ImageE =
        ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize);

    // Every image refers to the same (whole) offload entries table.
    ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB,
                                              ImageE, EntriesB, EntriesE));
  }

  // Then create images array.
  auto *ImagesData = ConstantArray::get(
      ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits);

  auto *Images =
      new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true,
                         GlobalValue::InternalLinkage, ImagesData,
                         ".omp_offloading.device_images");
  Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);

  // Pointer to the first element of the images array.
  auto *ImagesB =
      ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero);

  // And finally create the binary descriptor object.
  auto *DescInit = ConstantStruct::get(
      getBinDescTy(M),
      ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB,
      EntriesB, EntriesE);

  return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true,
                            GlobalValue::InternalLinkage, DescInit,
                            ".omp_offloading.descriptor");
}
223 
224 void createRegisterFunction(Module &M, GlobalVariable *BinDesc) {
225   LLVMContext &C = M.getContext();
226   auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
227   auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
228                                 ".omp_offloading.descriptor_reg", &M);
229   Func->setSection(".text.startup");
230 
231   // Get __tgt_register_lib function declaration.
232   auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
233                                       /*isVarArg*/ false);
234   FunctionCallee RegFuncC =
235       M.getOrInsertFunction("__tgt_register_lib", RegFuncTy);
236 
237   // Construct function body
238   IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
239   Builder.CreateCall(RegFuncC, BinDesc);
240   Builder.CreateRetVoid();
241 
242   // Add this function to constructors.
243   // Set priority to 1 so that __tgt_register_lib is executed AFTER
244   // __tgt_register_requires (we want to know what requirements have been
245   // asked for before we load a libomptarget plugin so that by the time the
246   // plugin is loaded it can report how many devices there are which can
247   // satisfy these requirements).
248   appendToGlobalCtors(M, Func, /*Priority*/ 1);
249 }
250 
251 void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) {
252   LLVMContext &C = M.getContext();
253   auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
254   auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
255                                 ".omp_offloading.descriptor_unreg", &M);
256   Func->setSection(".text.startup");
257 
258   // Get __tgt_unregister_lib function declaration.
259   auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
260                                         /*isVarArg*/ false);
261   FunctionCallee UnRegFuncC =
262       M.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy);
263 
264   // Construct function body
265   IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
266   Builder.CreateCall(UnRegFuncC, BinDesc);
267   Builder.CreateRetVoid();
268 
269   // Add this function to global destructors.
270   // Match priority of __tgt_register_lib
271   appendToGlobalDtors(M, Func, /*Priority*/ 1);
272 }
273 
274 // struct fatbin_wrapper {
275 //  int32_t magic;
276 //  int32_t version;
277 //  void *image;
278 //  void *reserved;
279 //};
280 StructType *getFatbinWrapperTy(Module &M) {
281   LLVMContext &C = M.getContext();
282   StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
283   if (!FatbinTy)
284     FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C),
285                                   Type::getInt32Ty(C), Type::getInt8PtrTy(C),
286                                   Type::getInt8PtrTy(C));
287   return FatbinTy;
288 }
289 
290 /// Embed the image \p Image into the module \p M so it can be found by the
291 /// runtime.
292 GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
293   LLVMContext &C = M.getContext();
294   llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C);
295   llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
296 
297   // Create the global string containing the fatbinary.
298   StringRef FatbinConstantSection =
299       IsHIP ? ".hip_fatbin"
300             : (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin");
301   auto *Data = ConstantDataArray::get(C, Image);
302   auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
303                                     GlobalVariable::InternalLinkage, Data,
304                                     ".fatbin_image");
305   Fatbin->setSection(FatbinConstantSection);
306 
307   // Create the fatbinary wrapper
308   StringRef FatbinWrapperSection = IsHIP               ? ".hipFatBinSegment"
309                                    : Triple.isMacOSX() ? "__NV_CUDA,__fatbin"
310                                                        : ".nvFatBinSegment";
311   Constant *FatbinWrapper[] = {
312       ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic),
313       ConstantInt::get(Type::getInt32Ty(C), 1),
314       ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
315       ConstantPointerNull::get(Type::getInt8PtrTy(C))};
316 
317   Constant *FatbinInitializer =
318       ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
319 
320   auto *FatbinDesc =
321       new GlobalVariable(M, getFatbinWrapperTy(M),
322                          /*isConstant*/ true, GlobalValue::InternalLinkage,
323                          FatbinInitializer, ".fatbin_wrapper");
324   FatbinDesc->setSection(FatbinWrapperSection);
325   FatbinDesc->setAlignment(Align(8));
326 
327   // We create a dummy entry to ensure the linker will define the begin / end
328   // symbols. The CUDA runtime should ignore the null address if we attempt to
329   // register it.
330   auto *DummyInit =
331       ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
332   auto *DummyEntry = new GlobalVariable(
333       M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
334       IsHIP ? "__dummy.hip_offloading.entry" : "__dummy.cuda_offloading.entry");
335   DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
336   DummyEntry->setSection(IsHIP ? "hip_offloading_entries"
337                                : "cuda_offloading_entries");
338 
339   return FatbinDesc;
340 }
341 
/// Create the register globals function. We will iterate all of the offloading
/// entries stored at the begin / end symbols and register them according to
/// their size and flags. This creates the following function in IR, written
/// here as C-like pseudocode for the CUDA case. For HIP, the `cuda` symbol
/// names become `hip` (`__hipRegisterFunction`, `__hipRegisterVar`,
/// `__start_hip_offloading_entries`, ...) and the function is named
/// `.hip.globals_reg`:
///
/// extern struct __tgt_offload_entry __start_cuda_offloading_entries[];
/// extern struct __tgt_offload_entry __stop_cuda_offloading_entries[];
///
/// extern int __cudaRegisterFunction(void **, void *, void *, void *, int,
///                                   void *, void *, void *, void *, int *);
/// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
///                               size_t, int32_t, int32_t);
///
/// static void .cuda.globals_reg(void **fatbinHandle) {
///   for (struct __tgt_offload_entry *entry = __start_cuda_offloading_entries;
///        entry != __stop_cuda_offloading_entries; ++entry) {
///     if (!entry->size)
///       __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
///                              entry->name, -1, 0, 0, 0, 0, 0);
///     else if (entry->flags == OffloadGlobalEntry)
///       __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
///                         0, entry->size, 0, 0);
///     // Managed, surface and texture entries are currently not registered.
///   }
/// }
///
/// The function is internal, placed in `.text.startup`, and returned to the
/// caller; it is not added to the global constructors here.
Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
  LLVMContext &C = M.getContext();
  // Get the __cudaRegisterFunction function declaration.
  auto *RegFuncTy = FunctionType::get(
      Type::getInt32Ty(C),
      {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C),
       Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)},
      /*isVarArg*/ false);
  FunctionCallee RegFunc = M.getOrInsertFunction(
      IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy);

  // Get the __cudaRegisterVar function declaration.
  auto *RegVarTy = FunctionType::get(
      Type::getVoidTy(C),
      {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
       getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
      /*isVarArg*/ false);
  FunctionCallee RegVar = M.getOrInsertFunction(
      IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy);

  // Create the references to the start / stop symbols defined by the linker.
  // They are typed as zero-length arrays of entries.
  auto *EntriesB =
      new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
                         /*isConstant*/ true, GlobalValue::ExternalLinkage,
                         /*Initializer*/ nullptr,
                         IsHIP ? "__start_hip_offloading_entries"
                               : "__start_cuda_offloading_entries");
  EntriesB->setVisibility(GlobalValue::HiddenVisibility);
  auto *EntriesE =
      new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
                         /*isConstant*/ true, GlobalValue::ExternalLinkage,
                         /*Initializer*/ nullptr,
                         IsHIP ? "__stop_hip_offloading_entries"
                               : "__stop_cuda_offloading_entries");
  EntriesE->setVisibility(GlobalValue::HiddenVisibility);

  // void <name>(void **fatbinHandle)
  auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C),
                                         Type::getInt8PtrTy(C)->getPointerTo(),
                                         /*isVarArg*/ false);
  auto *RegGlobalsFn =
      Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage,
                       IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg", &M);
  RegGlobalsFn->setSection(".text.startup");

  // Create the loop to register all the entries.
  IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
  auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
  auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
  auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
  auto *SwGlobalBB = BasicBlock::Create(C, "sw.global", RegGlobalsFn);
  auto *SwManagedBB = BasicBlock::Create(C, "sw.managed", RegGlobalsFn);
  auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface", RegGlobalsFn);
  auto *SwTextureBB = BasicBlock::Create(C, "sw.texture", RegGlobalsFn);
  auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
  auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);

  // Skip the loop entirely when the entries table is empty (begin == end).
  auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE);
  Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB);
  Builder.SetInsertPoint(EntryBB);
  // Loop-carried pointer to the current entry; its incoming values are added
  // once the increment in if.end has been built.
  auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry");
  // Load the addr (0), name (1), size (2) and flags (3) fields of the entry.
  auto *AddrPtr =
      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
                                {ConstantInt::get(getSizeTTy(M), 0),
                                 ConstantInt::get(Type::getInt32Ty(C), 0)});
  auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr");
  auto *NamePtr =
      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
                                {ConstantInt::get(getSizeTTy(M), 0),
                                 ConstantInt::get(Type::getInt32Ty(C), 1)});
  auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name");
  auto *SizePtr =
      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
                                {ConstantInt::get(getSizeTTy(M), 0),
                                 ConstantInt::get(Type::getInt32Ty(C), 2)});
  auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
  auto *FlagsPtr =
      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
                                {ConstantInt::get(getSizeTTy(M), 0),
                                 ConstantInt::get(Type::getInt32Ty(C), 3)});
  auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flag");
  // A zero size marks a kernel; anything else is dispatched on the flags.
  auto *FnCond =
      Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
  Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);

  // Create kernel registration code.
  Builder.SetInsertPoint(IfThenBB);
  Builder.CreateCall(RegFunc,
                     {RegGlobalsFn->arg_begin(), Addr, Name, Name,
                      ConstantInt::get(Type::getInt32Ty(C), -1),
                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
                      ConstantPointerNull::get(Type::getInt32PtrTy(C))});
  Builder.CreateBr(IfEndBB);
  Builder.SetInsertPoint(IfElseBB);

  // Unknown flag values fall through to if.end without registering anything.
  auto *Switch = Builder.CreateSwitch(Flags, IfEndBB);
  // Create global variable registration code.
  Builder.SetInsertPoint(SwGlobalBB);
  Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
                              ConstantInt::get(Type::getInt32Ty(C), 0), Size,
                              ConstantInt::get(Type::getInt32Ty(C), 0),
                              ConstantInt::get(Type::getInt32Ty(C), 0)});
  Builder.CreateBr(IfEndBB);
  Switch->addCase(Builder.getInt32(OffloadGlobalEntry), SwGlobalBB);

  // TODO(review): the managed, surface and texture cases below only branch to
  // if.end; no registration call is emitted for these entry kinds yet.

  // Create managed variable registration code.
  Builder.SetInsertPoint(SwManagedBB);
  Builder.CreateBr(IfEndBB);
  Switch->addCase(Builder.getInt32(OffloadGlobalManagedEntry), SwManagedBB);

  // Create surface variable registration code.
  Builder.SetInsertPoint(SwSurfaceBB);
  Builder.CreateBr(IfEndBB);
  Switch->addCase(Builder.getInt32(OffloadGlobalSurfaceEntry), SwSurfaceBB);

  // Create texture variable registration code.
  Builder.SetInsertPoint(SwTextureBB);
  Builder.CreateBr(IfEndBB);
  Switch->addCase(Builder.getInt32(OffloadGlobalTextureEntry), SwTextureBB);

  // Advance to the next entry and stop once it reaches the end symbol.
  Builder.SetInsertPoint(IfEndBB);
  auto *NewEntry = Builder.CreateInBoundsGEP(
      getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
  auto *Cmp = Builder.CreateICmpEQ(
      NewEntry,
      ConstantExpr::getInBoundsGetElementPtr(
          ArrayType::get(getEntryTy(M), 0), EntriesE,
          ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
                                ConstantInt::get(getSizeTTy(M), 0)})));
  // The first iteration starts at element 0 of the begin symbol's array.
  Entry->addIncoming(
      ConstantExpr::getInBoundsGetElementPtr(
          ArrayType::get(getEntryTy(M), 0), EntriesB,
          ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
                                ConstantInt::get(getSizeTTy(M), 0)})),
      &RegGlobalsFn->getEntryBlock());
  Entry->addIncoming(NewEntry, IfEndBB);
  Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
  Builder.SetInsertPoint(ExitBB);
  Builder.CreateRetVoid();

  return RegGlobalsFn;
}
512 
513 // Create the constructor and destructor to register the fatbinary with the CUDA
514 // runtime.
515 void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
516                                   bool IsHIP) {
517   LLVMContext &C = M.getContext();
518   auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
519   auto *CtorFunc =
520       Function::Create(CtorFuncTy, GlobalValue::InternalLinkage,
521                        IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M);
522   CtorFunc->setSection(".text.startup");
523 
524   auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
525   auto *DtorFunc =
526       Function::Create(DtorFuncTy, GlobalValue::InternalLinkage,
527                        IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M);
528   DtorFunc->setSection(".text.startup");
529 
530   // Get the __cudaRegisterFatBinary function declaration.
531   auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(),
532                                      Type::getInt8PtrTy(C),
533                                      /*isVarArg*/ false);
534   FunctionCallee RegFatbin = M.getOrInsertFunction(
535       IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy);
536   // Get the __cudaRegisterFatBinaryEnd function declaration.
537   auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C),
538                                         Type::getInt8PtrTy(C)->getPointerTo(),
539                                         /*isVarArg*/ false);
540   FunctionCallee RegFatbinEnd =
541       M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
542   // Get the __cudaUnregisterFatBinary function declaration.
543   auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C),
544                                        Type::getInt8PtrTy(C)->getPointerTo(),
545                                        /*isVarArg*/ false);
546   FunctionCallee UnregFatbin = M.getOrInsertFunction(
547       IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary",
548       UnregFatTy);
549 
550   auto *AtExitTy =
551       FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(),
552                         /*isVarArg*/ false);
553   FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
554 
555   auto *BinaryHandleGlobal = new llvm::GlobalVariable(
556       M, Type::getInt8PtrTy(C)->getPointerTo(), false,
557       llvm::GlobalValue::InternalLinkage,
558       llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()),
559       IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle");
560 
561   // Create the constructor to register this image with the runtime.
562   IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
563   CallInst *Handle = CtorBuilder.CreateCall(
564       RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast(
565                      FatbinDesc, Type::getInt8PtrTy(C)));
566   CtorBuilder.CreateAlignedStore(
567       Handle, BinaryHandleGlobal,
568       Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
569   CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle);
570   if (!IsHIP)
571     CtorBuilder.CreateCall(RegFatbinEnd, Handle);
572   CtorBuilder.CreateCall(AtExit, DtorFunc);
573   CtorBuilder.CreateRetVoid();
574 
575   // Create the destructor to unregister the image with the runtime. We cannot
576   // use a standard global destructor after CUDA 9.2 so this must be called by
577   // `atexit()` intead.
578   IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
579   LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
580       Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal,
581       Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
582   DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
583   DtorBuilder.CreateRetVoid();
584 
585   // Add this function to constructors.
586   appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1);
587 }
588 
589 } // namespace
590 
591 Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
592   GlobalVariable *Desc = createBinDesc(M, Images);
593   if (!Desc)
594     return createStringError(inconvertibleErrorCode(),
595                              "No binary descriptors created.");
596   createRegisterFunction(M, Desc);
597   createUnregisterFunction(M, Desc);
598   return Error::success();
599 }
600 
601 Error wrapCudaBinary(Module &M, ArrayRef<char> Image) {
602   GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ false);
603   if (!Desc)
604     return createStringError(inconvertibleErrorCode(),
605                              "No fatinbary section created.");
606 
607   createRegisterFatbinFunction(M, Desc, /* IsHIP */ false);
608   return Error::success();
609 }
610 
611 Error wrapHIPBinary(Module &M, ArrayRef<char> Image) {
612   GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ true);
613   if (!Desc)
614     return createStringError(inconvertibleErrorCode(),
615                              "No fatinbary section created.");
616 
617   createRegisterFatbinFunction(M, Desc, /* IsHIP */ true);
618   return Error::success();
619 }
620