1 //===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "OffloadWrapper.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/ADT/Triple.h" 12 #include "llvm/IR/Constants.h" 13 #include "llvm/IR/GlobalVariable.h" 14 #include "llvm/IR/IRBuilder.h" 15 #include "llvm/IR/LLVMContext.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/Support/Error.h" 18 #include "llvm/Transforms/Utils/ModuleUtils.h" 19 20 using namespace llvm; 21 22 namespace { 23 /// Magic number that begins the section containing the CUDA fatbinary. 24 constexpr unsigned CudaFatMagic = 0x466243b1; 25 constexpr unsigned HIPFatMagic = 0x48495046; 26 27 /// Copied from clang/CGCudaRuntime.h. 28 enum OffloadEntryKindFlag : uint32_t { 29 /// Mark the entry as a global entry. This indicates the presense of a 30 /// kernel if the size size field is zero and a variable otherwise. 31 OffloadGlobalEntry = 0x0, 32 /// Mark the entry as a managed global variable. 33 OffloadGlobalManagedEntry = 0x1, 34 /// Mark the entry as a surface variable. 35 OffloadGlobalSurfaceEntry = 0x2, 36 /// Mark the entry as a texture variable. 37 OffloadGlobalTextureEntry = 0x3, 38 }; 39 40 IntegerType *getSizeTTy(Module &M) { 41 LLVMContext &C = M.getContext(); 42 switch (M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))) { 43 case 4u: 44 return Type::getInt32Ty(C); 45 case 8u: 46 return Type::getInt64Ty(C); 47 } 48 llvm_unreachable("unsupported pointer type size"); 49 } 50 51 // struct __tgt_offload_entry { 52 // void *addr; 53 // char *name; 54 // size_t size; 55 // int32_t flags; 56 // int32_t reserved; 57 // }; 58 StructType *getEntryTy(Module &M) { 59 LLVMContext &C = M.getContext(); 60 StructType *EntryTy = StructType::getTypeByName(C, "__tgt_offload_entry"); 61 if (!EntryTy) 62 EntryTy = StructType::create("__tgt_offload_entry", Type::getInt8PtrTy(C), 63 Type::getInt8PtrTy(C), getSizeTTy(M), 64 Type::getInt32Ty(C), Type::getInt32Ty(C)); 65 return EntryTy; 66 } 67 68 PointerType *getEntryPtrTy(Module &M) { 69 return PointerType::getUnqual(getEntryTy(M)); 70 } 71 72 // struct __tgt_device_image { 73 // void *ImageStart; 74 // void *ImageEnd; 75 // __tgt_offload_entry *EntriesBegin; 76 // __tgt_offload_entry *EntriesEnd; 77 // }; 78 StructType *getDeviceImageTy(Module &M) { 79 LLVMContext &C = M.getContext(); 80 StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image"); 81 if (!ImageTy) 82 ImageTy = StructType::create("__tgt_device_image", Type::getInt8PtrTy(C), 83 Type::getInt8PtrTy(C), getEntryPtrTy(M), 84 getEntryPtrTy(M)); 85 return ImageTy; 86 } 87 88 PointerType *getDeviceImagePtrTy(Module &M) { 89 return PointerType::getUnqual(getDeviceImageTy(M)); 90 } 91 92 // struct __tgt_bin_desc { 93 // int32_t NumDeviceImages; 94 // __tgt_device_image *DeviceImages; 95 // __tgt_offload_entry *HostEntriesBegin; 96 // __tgt_offload_entry *HostEntriesEnd; 97 // }; 98 StructType *getBinDescTy(Module &M) { 99 LLVMContext &C = M.getContext(); 100 StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc"); 101 if (!DescTy) 102 DescTy = StructType::create("__tgt_bin_desc", Type::getInt32Ty(C), 103 getDeviceImagePtrTy(M), getEntryPtrTy(M), 104 getEntryPtrTy(M)); 105 return DescTy; 106 } 107 108 PointerType *getBinDescPtrTy(Module &M) { 109 return PointerType::getUnqual(getBinDescTy(M)); 110 } 111 112 /// Creates binary descriptor for the given device images. Binary descriptor 113 /// is an object that is passed to the offloading runtime at program startup 114 /// and it describes all device images available in the executable or shared 115 /// library. It is defined as follows 116 /// 117 /// __attribute__((visibility("hidden"))) 118 /// extern __tgt_offload_entry *__start_omp_offloading_entries; 119 /// __attribute__((visibility("hidden"))) 120 /// extern __tgt_offload_entry *__stop_omp_offloading_entries; 121 /// 122 /// static const char Image0[] = { <Bufs.front() contents> }; 123 /// ... 124 /// static const char ImageN[] = { <Bufs.back() contents> }; 125 /// 126 /// static const __tgt_device_image Images[] = { 127 /// { 128 /// Image0, /*ImageStart*/ 129 /// Image0 + sizeof(Image0), /*ImageEnd*/ 130 /// __start_omp_offloading_entries, /*EntriesBegin*/ 131 /// __stop_omp_offloading_entries /*EntriesEnd*/ 132 /// }, 133 /// ... 134 /// { 135 /// ImageN, /*ImageStart*/ 136 /// ImageN + sizeof(ImageN), /*ImageEnd*/ 137 /// __start_omp_offloading_entries, /*EntriesBegin*/ 138 /// __stop_omp_offloading_entries /*EntriesEnd*/ 139 /// } 140 /// }; 141 /// 142 /// static const __tgt_bin_desc BinDesc = { 143 /// sizeof(Images) / sizeof(Images[0]), /*NumDeviceImages*/ 144 /// Images, /*DeviceImages*/ 145 /// __start_omp_offloading_entries, /*HostEntriesBegin*/ 146 /// __stop_omp_offloading_entries /*HostEntriesEnd*/ 147 /// }; 148 /// 149 /// Global variable that represents BinDesc is returned. 150 GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) { 151 LLVMContext &C = M.getContext(); 152 // Create external begin/end symbols for the offload entries table. 153 auto *EntriesB = new GlobalVariable( 154 M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage, 155 /*Initializer*/ nullptr, "__start_omp_offloading_entries"); 156 EntriesB->setVisibility(GlobalValue::HiddenVisibility); 157 auto *EntriesE = new GlobalVariable( 158 M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage, 159 /*Initializer*/ nullptr, "__stop_omp_offloading_entries"); 160 EntriesE->setVisibility(GlobalValue::HiddenVisibility); 161 162 // We assume that external begin/end symbols that we have created above will 163 // be defined by the linker. But linker will do that only if linker inputs 164 // have section with "omp_offloading_entries" name which is not guaranteed. 165 // So, we just create dummy zero sized object in the offload entries section 166 // to force linker to define those symbols. 167 auto *DummyInit = 168 ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u)); 169 auto *DummyEntry = new GlobalVariable( 170 M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit, 171 "__dummy.omp_offloading.entry"); 172 DummyEntry->setSection("omp_offloading_entries"); 173 DummyEntry->setVisibility(GlobalValue::HiddenVisibility); 174 175 auto *Zero = ConstantInt::get(getSizeTTy(M), 0u); 176 Constant *ZeroZero[] = {Zero, Zero}; 177 178 // Create initializer for the images array. 179 SmallVector<Constant *, 4u> ImagesInits; 180 ImagesInits.reserve(Bufs.size()); 181 for (ArrayRef<char> Buf : Bufs) { 182 auto *Data = ConstantDataArray::get(C, Buf); 183 auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant*/ true, 184 GlobalVariable::InternalLinkage, Data, 185 ".omp_offloading.device_image"); 186 Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); 187 188 auto *Size = ConstantInt::get(getSizeTTy(M), Buf.size()); 189 Constant *ZeroSize[] = {Zero, Size}; 190 191 auto *ImageB = 192 ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroZero); 193 auto *ImageE = 194 ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize); 195 196 ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB, 197 ImageE, EntriesB, EntriesE)); 198 } 199 200 // Then create images array. 201 auto *ImagesData = ConstantArray::get( 202 ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits); 203 204 auto *Images = 205 new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true, 206 GlobalValue::InternalLinkage, ImagesData, 207 ".omp_offloading.device_images"); 208 Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); 209 210 auto *ImagesB = 211 ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero); 212 213 // And finally create the binary descriptor object. 214 auto *DescInit = ConstantStruct::get( 215 getBinDescTy(M), 216 ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB, 217 EntriesB, EntriesE); 218 219 return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true, 220 GlobalValue::InternalLinkage, DescInit, 221 ".omp_offloading.descriptor"); 222 } 223 224 void createRegisterFunction(Module &M, GlobalVariable *BinDesc) { 225 LLVMContext &C = M.getContext(); 226 auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); 227 auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage, 228 ".omp_offloading.descriptor_reg", &M); 229 Func->setSection(".text.startup"); 230 231 // Get __tgt_register_lib function declaration. 232 auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M), 233 /*isVarArg*/ false); 234 FunctionCallee RegFuncC = 235 M.getOrInsertFunction("__tgt_register_lib", RegFuncTy); 236 237 // Construct function body 238 IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func)); 239 Builder.CreateCall(RegFuncC, BinDesc); 240 Builder.CreateRetVoid(); 241 242 // Add this function to constructors. 243 // Set priority to 1 so that __tgt_register_lib is executed AFTER 244 // __tgt_register_requires (we want to know what requirements have been 245 // asked for before we load a libomptarget plugin so that by the time the 246 // plugin is loaded it can report how many devices there are which can 247 // satisfy these requirements). 248 appendToGlobalCtors(M, Func, /*Priority*/ 1); 249 } 250 251 void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) { 252 LLVMContext &C = M.getContext(); 253 auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); 254 auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage, 255 ".omp_offloading.descriptor_unreg", &M); 256 Func->setSection(".text.startup"); 257 258 // Get __tgt_unregister_lib function declaration. 259 auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M), 260 /*isVarArg*/ false); 261 FunctionCallee UnRegFuncC = 262 M.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy); 263 264 // Construct function body 265 IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func)); 266 Builder.CreateCall(UnRegFuncC, BinDesc); 267 Builder.CreateRetVoid(); 268 269 // Add this function to global destructors. 270 // Match priority of __tgt_register_lib 271 appendToGlobalDtors(M, Func, /*Priority*/ 1); 272 } 273 274 // struct fatbin_wrapper { 275 // int32_t magic; 276 // int32_t version; 277 // void *image; 278 // void *reserved; 279 //}; 280 StructType *getFatbinWrapperTy(Module &M) { 281 LLVMContext &C = M.getContext(); 282 StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper"); 283 if (!FatbinTy) 284 FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C), 285 Type::getInt32Ty(C), Type::getInt8PtrTy(C), 286 Type::getInt8PtrTy(C)); 287 return FatbinTy; 288 } 289 290 /// Embed the image \p Image into the module \p M so it can be found by the 291 /// runtime. 292 GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) { 293 LLVMContext &C = M.getContext(); 294 llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C); 295 llvm::Triple Triple = llvm::Triple(M.getTargetTriple()); 296 297 // Create the global string containing the fatbinary. 298 StringRef FatbinConstantSection = 299 IsHIP ? ".hip_fatbin" 300 : (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin"); 301 auto *Data = ConstantDataArray::get(C, Image); 302 auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true, 303 GlobalVariable::InternalLinkage, Data, 304 ".fatbin_image"); 305 Fatbin->setSection(FatbinConstantSection); 306 307 // Create the fatbinary wrapper 308 StringRef FatbinWrapperSection = IsHIP ? ".hipFatBinSegment" 309 : Triple.isMacOSX() ? "__NV_CUDA,__fatbin" 310 : ".nvFatBinSegment"; 311 Constant *FatbinWrapper[] = { 312 ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic), 313 ConstantInt::get(Type::getInt32Ty(C), 1), 314 ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy), 315 ConstantPointerNull::get(Type::getInt8PtrTy(C))}; 316 317 Constant *FatbinInitializer = 318 ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper); 319 320 auto *FatbinDesc = 321 new GlobalVariable(M, getFatbinWrapperTy(M), 322 /*isConstant*/ true, GlobalValue::InternalLinkage, 323 FatbinInitializer, ".fatbin_wrapper"); 324 FatbinDesc->setSection(FatbinWrapperSection); 325 FatbinDesc->setAlignment(Align(8)); 326 327 // We create a dummy entry to ensure the linker will define the begin / end 328 // symbols. The CUDA runtime should ignore the null address if we attempt to 329 // register it. 330 auto *DummyInit = 331 ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u)); 332 auto *DummyEntry = new GlobalVariable( 333 M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit, 334 IsHIP ? "__dummy.hip_offloading.entry" : "__dummy.cuda_offloading.entry"); 335 DummyEntry->setVisibility(GlobalValue::HiddenVisibility); 336 DummyEntry->setSection(IsHIP ? "hip_offloading_entries" 337 : "cuda_offloading_entries"); 338 339 return FatbinDesc; 340 } 341 342 /// Create the register globals function. We will iterate all of the offloading 343 /// entries stored at the begin / end symbols and register them according to 344 /// their type. This creates the following function in IR: 345 /// 346 /// extern struct __tgt_offload_entry __start_cuda_offloading_entries; 347 /// extern struct __tgt_offload_entry __stop_cuda_offloading_entries; 348 /// 349 /// extern void __cudaRegisterFunction(void **, void *, void *, void *, int, 350 /// void *, void *, void *, void *, int *); 351 /// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t, 352 /// int64_t, int32_t, int32_t); 353 /// 354 /// void __cudaRegisterTest(void **fatbinHandle) { 355 /// for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries; 356 /// entry != &__stop_cuda_offloading_entries; ++entry) { 357 /// if (!entry->size) 358 /// __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name, 359 /// entry->name, -1, 0, 0, 0, 0, 0); 360 /// else 361 /// __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name, 362 /// 0, entry->size, 0, 0); 363 /// } 364 /// } 365 Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) { 366 LLVMContext &C = M.getContext(); 367 // Get the __cudaRegisterFunction function declaration. 368 auto *RegFuncTy = FunctionType::get( 369 Type::getInt32Ty(C), 370 {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C), 371 Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C), 372 Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), 373 Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)}, 374 /*isVarArg*/ false); 375 FunctionCallee RegFunc = M.getOrInsertFunction( 376 IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy); 377 378 // Get the __cudaRegisterVar function declaration. 379 auto *RegVarTy = FunctionType::get( 380 Type::getVoidTy(C), 381 {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C), 382 Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C), 383 getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)}, 384 /*isVarArg*/ false); 385 FunctionCallee RegVar = M.getOrInsertFunction( 386 IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy); 387 388 // Create the references to the start / stop symbols defined by the linker. 389 auto *EntriesB = 390 new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0), 391 /*isConstant*/ true, GlobalValue::ExternalLinkage, 392 /*Initializer*/ nullptr, 393 IsHIP ? "__start_hip_offloading_entries" 394 : "__start_cuda_offloading_entries"); 395 EntriesB->setVisibility(GlobalValue::HiddenVisibility); 396 auto *EntriesE = 397 new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0), 398 /*isConstant*/ true, GlobalValue::ExternalLinkage, 399 /*Initializer*/ nullptr, 400 IsHIP ? "__stop_hip_offloading_entries" 401 : "__stop_cuda_offloading_entries"); 402 EntriesE->setVisibility(GlobalValue::HiddenVisibility); 403 404 auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C), 405 Type::getInt8PtrTy(C)->getPointerTo(), 406 /*isVarArg*/ false); 407 auto *RegGlobalsFn = 408 Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage, 409 IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg", &M); 410 RegGlobalsFn->setSection(".text.startup"); 411 412 // Create the loop to register all the entries. 413 IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn)); 414 auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn); 415 auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn); 416 auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn); 417 auto *SwGlobalBB = BasicBlock::Create(C, "sw.global", RegGlobalsFn); 418 auto *SwManagedBB = BasicBlock::Create(C, "sw.managed", RegGlobalsFn); 419 auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface", RegGlobalsFn); 420 auto *SwTextureBB = BasicBlock::Create(C, "sw.texture", RegGlobalsFn); 421 auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn); 422 auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn); 423 424 auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE); 425 Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB); 426 Builder.SetInsertPoint(EntryBB); 427 auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry"); 428 auto *AddrPtr = 429 Builder.CreateInBoundsGEP(getEntryTy(M), Entry, 430 {ConstantInt::get(getSizeTTy(M), 0), 431 ConstantInt::get(Type::getInt32Ty(C), 0)}); 432 auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr"); 433 auto *NamePtr = 434 Builder.CreateInBoundsGEP(getEntryTy(M), Entry, 435 {ConstantInt::get(getSizeTTy(M), 0), 436 ConstantInt::get(Type::getInt32Ty(C), 1)}); 437 auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name"); 438 auto *SizePtr = 439 Builder.CreateInBoundsGEP(getEntryTy(M), Entry, 440 {ConstantInt::get(getSizeTTy(M), 0), 441 ConstantInt::get(Type::getInt32Ty(C), 2)}); 442 auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size"); 443 auto *FlagsPtr = 444 Builder.CreateInBoundsGEP(getEntryTy(M), Entry, 445 {ConstantInt::get(getSizeTTy(M), 0), 446 ConstantInt::get(Type::getInt32Ty(C), 3)}); 447 auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flag"); 448 auto *FnCond = 449 Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M))); 450 Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB); 451 452 // Create kernel registration code. 453 Builder.SetInsertPoint(IfThenBB); 454 Builder.CreateCall(RegFunc, 455 {RegGlobalsFn->arg_begin(), Addr, Name, Name, 456 ConstantInt::get(Type::getInt32Ty(C), -1), 457 ConstantPointerNull::get(Type::getInt8PtrTy(C)), 458 ConstantPointerNull::get(Type::getInt8PtrTy(C)), 459 ConstantPointerNull::get(Type::getInt8PtrTy(C)), 460 ConstantPointerNull::get(Type::getInt8PtrTy(C)), 461 ConstantPointerNull::get(Type::getInt32PtrTy(C))}); 462 Builder.CreateBr(IfEndBB); 463 Builder.SetInsertPoint(IfElseBB); 464 465 auto *Switch = Builder.CreateSwitch(Flags, IfEndBB); 466 // Create global variable registration code. 467 Builder.SetInsertPoint(SwGlobalBB); 468 Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name, 469 ConstantInt::get(Type::getInt32Ty(C), 0), Size, 470 ConstantInt::get(Type::getInt32Ty(C), 0), 471 ConstantInt::get(Type::getInt32Ty(C), 0)}); 472 Builder.CreateBr(IfEndBB); 473 Switch->addCase(Builder.getInt32(OffloadGlobalEntry), SwGlobalBB); 474 475 // Create managed variable registration code. 476 Builder.SetInsertPoint(SwManagedBB); 477 Builder.CreateBr(IfEndBB); 478 Switch->addCase(Builder.getInt32(OffloadGlobalManagedEntry), SwManagedBB); 479 480 // Create surface variable registration code. 481 Builder.SetInsertPoint(SwSurfaceBB); 482 Builder.CreateBr(IfEndBB); 483 Switch->addCase(Builder.getInt32(OffloadGlobalSurfaceEntry), SwSurfaceBB); 484 485 // Create texture variable registration code. 486 Builder.SetInsertPoint(SwTextureBB); 487 Builder.CreateBr(IfEndBB); 488 Switch->addCase(Builder.getInt32(OffloadGlobalTextureEntry), SwTextureBB); 489 490 Builder.SetInsertPoint(IfEndBB); 491 auto *NewEntry = Builder.CreateInBoundsGEP( 492 getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1)); 493 auto *Cmp = Builder.CreateICmpEQ( 494 NewEntry, 495 ConstantExpr::getInBoundsGetElementPtr( 496 ArrayType::get(getEntryTy(M), 0), EntriesE, 497 ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0), 498 ConstantInt::get(getSizeTTy(M), 0)}))); 499 Entry->addIncoming( 500 ConstantExpr::getInBoundsGetElementPtr( 501 ArrayType::get(getEntryTy(M), 0), EntriesB, 502 ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0), 503 ConstantInt::get(getSizeTTy(M), 0)})), 504 &RegGlobalsFn->getEntryBlock()); 505 Entry->addIncoming(NewEntry, IfEndBB); 506 Builder.CreateCondBr(Cmp, ExitBB, EntryBB); 507 Builder.SetInsertPoint(ExitBB); 508 Builder.CreateRetVoid(); 509 510 return RegGlobalsFn; 511 } 512 513 // Create the constructor and destructor to register the fatbinary with the CUDA 514 // runtime. 515 void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc, 516 bool IsHIP) { 517 LLVMContext &C = M.getContext(); 518 auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); 519 auto *CtorFunc = 520 Function::Create(CtorFuncTy, GlobalValue::InternalLinkage, 521 IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M); 522 CtorFunc->setSection(".text.startup"); 523 524 auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); 525 auto *DtorFunc = 526 Function::Create(DtorFuncTy, GlobalValue::InternalLinkage, 527 IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M); 528 DtorFunc->setSection(".text.startup"); 529 530 // Get the __cudaRegisterFatBinary function declaration. 531 auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(), 532 Type::getInt8PtrTy(C), 533 /*isVarArg*/ false); 534 FunctionCallee RegFatbin = M.getOrInsertFunction( 535 IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy); 536 // Get the __cudaRegisterFatBinaryEnd function declaration. 537 auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C), 538 Type::getInt8PtrTy(C)->getPointerTo(), 539 /*isVarArg*/ false); 540 FunctionCallee RegFatbinEnd = 541 M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy); 542 // Get the __cudaUnregisterFatBinary function declaration. 543 auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C), 544 Type::getInt8PtrTy(C)->getPointerTo(), 545 /*isVarArg*/ false); 546 FunctionCallee UnregFatbin = M.getOrInsertFunction( 547 IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary", 548 UnregFatTy); 549 550 auto *AtExitTy = 551 FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(), 552 /*isVarArg*/ false); 553 FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy); 554 555 auto *BinaryHandleGlobal = new llvm::GlobalVariable( 556 M, Type::getInt8PtrTy(C)->getPointerTo(), false, 557 llvm::GlobalValue::InternalLinkage, 558 llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()), 559 IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle"); 560 561 // Create the constructor to register this image with the runtime. 562 IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc)); 563 CallInst *Handle = CtorBuilder.CreateCall( 564 RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast( 565 FatbinDesc, Type::getInt8PtrTy(C))); 566 CtorBuilder.CreateAlignedStore( 567 Handle, BinaryHandleGlobal, 568 Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C)))); 569 CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle); 570 if (!IsHIP) 571 CtorBuilder.CreateCall(RegFatbinEnd, Handle); 572 CtorBuilder.CreateCall(AtExit, DtorFunc); 573 CtorBuilder.CreateRetVoid(); 574 575 // Create the destructor to unregister the image with the runtime. We cannot 576 // use a standard global destructor after CUDA 9.2 so this must be called by 577 // `atexit()` intead. 578 IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc)); 579 LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad( 580 Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal, 581 Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C)))); 582 DtorBuilder.CreateCall(UnregFatbin, BinaryHandle); 583 DtorBuilder.CreateRetVoid(); 584 585 // Add this function to constructors. 586 appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1); 587 } 588 589 } // namespace 590 591 Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) { 592 GlobalVariable *Desc = createBinDesc(M, Images); 593 if (!Desc) 594 return createStringError(inconvertibleErrorCode(), 595 "No binary descriptors created."); 596 createRegisterFunction(M, Desc); 597 createUnregisterFunction(M, Desc); 598 return Error::success(); 599 } 600 601 Error wrapCudaBinary(Module &M, ArrayRef<char> Image) { 602 GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ false); 603 if (!Desc) 604 return createStringError(inconvertibleErrorCode(), 605 "No fatinbary section created."); 606 607 createRegisterFatbinFunction(M, Desc, /* IsHIP */ false); 608 return Error::success(); 609 } 610 611 Error wrapHIPBinary(Module &M, ArrayRef<char> Image) { 612 GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ true); 613 if (!Desc) 614 return createStringError(inconvertibleErrorCode(), 615 "No fatinbary section created."); 616 617 createRegisterFatbinFunction(M, Desc, /* IsHIP */ true); 618 return Error::success(); 619 } 620