1 //===------- Utils.cpp - OpenMP device runtime utility functions -- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // 10 //===----------------------------------------------------------------------===// 11 12 #include "Utils.h" 13 14 #include "Debug.h" 15 #include "Interface.h" 16 #include "Mapping.h" 17 18 #pragma omp begin declare target device_type(nohost) 19 20 using namespace _OMP; 21 22 namespace _OMP { 23 /// Helper to keep code alive without introducing a performance penalty. 24 __attribute__((weak, optnone, cold)) KEEP_ALIVE void keepAlive() { 25 __kmpc_get_hardware_thread_id_in_block(); 26 __kmpc_get_hardware_num_threads_in_block(); 27 __kmpc_get_warp_size(); 28 __kmpc_barrier_simple_spmd(nullptr, 0); 29 __kmpc_barrier_simple_generic(nullptr, 0); 30 } 31 } // namespace _OMP 32 33 namespace impl { 34 35 void Unpack(uint64_t Val, uint32_t *LowBits, uint32_t *HighBits); 36 uint64_t Pack(uint32_t LowBits, uint32_t HighBits); 37 38 /// AMDGCN Implementation 39 /// 40 ///{ 41 #pragma omp begin declare variant match(device = {arch(amdgcn)}) 42 43 void Unpack(uint64_t Val, uint32_t *LowBits, uint32_t *HighBits) { 44 static_assert(sizeof(unsigned long) == 8, ""); 45 *LowBits = (uint32_t)(Val & 0x00000000FFFFFFFFUL); 46 *HighBits = (uint32_t)((Val & 0xFFFFFFFF00000000UL) >> 32); 47 } 48 49 uint64_t Pack(uint32_t LowBits, uint32_t HighBits) { 50 return (((uint64_t)HighBits) << 32) | (uint64_t)LowBits; 51 } 52 53 #pragma omp end declare variant 54 55 /// NVPTX Implementation 56 /// 57 ///{ 58 #pragma omp begin declare variant match( \ 59 device = {arch(nvptx, nvptx64)}, implementation = {extension(match_any)}) 60 61 void Unpack(uint64_t Val, uint32_t *LowBits, uint32_t *HighBits) { 62 uint32_t LowBitsLocal, HighBitsLocal; 63 asm("mov.b64 {%0,%1}, %2;" 64 : "=r"(LowBitsLocal), "=r"(HighBitsLocal) 65 : "l"(Val)); 66 *LowBits = LowBitsLocal; 67 *HighBits = HighBitsLocal; 68 } 69 70 uint64_t Pack(uint32_t LowBits, uint32_t HighBits) { 71 uint64_t Val; 72 asm("mov.b64 %0, {%1,%2};" : "=l"(Val) : "r"(LowBits), "r"(HighBits)); 73 return Val; 74 } 75 76 #pragma omp end declare variant 77 78 int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane); 79 int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t LaneDelta, 80 int32_t Width); 81 82 /// AMDGCN Implementation 83 /// 84 ///{ 85 #pragma omp begin declare variant match(device = {arch(amdgcn)}) 86 87 int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane) { 88 int Width = mapping::getWarpSize(); 89 int Self = mapping::getThreadIdInWarp(); 90 int Index = SrcLane + (Self & ~(Width - 1)); 91 return __builtin_amdgcn_ds_bpermute(Index << 2, Var); 92 } 93 94 int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t LaneDelta, 95 int32_t Width) { 96 int Self = mapping::getThreadIdInWarp(); 97 int Index = Self + LaneDelta; 98 Index = (int)(LaneDelta + (Self & (Width - 1))) >= Width ? Self : Index; 99 return __builtin_amdgcn_ds_bpermute(Index << 2, Var); 100 } 101 102 #pragma omp end declare variant 103 ///} 104 105 /// NVPTX Implementation 106 /// 107 ///{ 108 #pragma omp begin declare variant match( \ 109 device = {arch(nvptx, nvptx64)}, implementation = {extension(match_any)}) 110 111 int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane) { 112 return __nvvm_shfl_sync_idx_i32(Mask, Var, SrcLane, 0x1f); 113 } 114 115 int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta, int32_t Width) { 116 int32_t T = ((mapping::getWarpSize() - Width) << 8) | 0x1f; 117 return __nvvm_shfl_sync_down_i32(Mask, Var, Delta, T); 118 } 119 120 #pragma omp end declare variant 121 } // namespace impl 122 123 uint64_t utils::pack(uint32_t LowBits, uint32_t HighBits) { 124 return impl::Pack(LowBits, HighBits); 125 } 126 127 void utils::unpack(uint64_t Val, uint32_t &LowBits, uint32_t &HighBits) { 128 impl::Unpack(Val, &LowBits, &HighBits); 129 } 130 131 int32_t utils::shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane) { 132 return impl::shuffle(Mask, Var, SrcLane); 133 } 134 135 int32_t utils::shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta, 136 int32_t Width) { 137 return impl::shuffleDown(Mask, Var, Delta, Width); 138 } 139 140 extern "C" { 141 int32_t __kmpc_shuffle_int32(int32_t Val, int16_t Delta, int16_t SrcLane) { 142 FunctionTracingRAII(); 143 return impl::shuffleDown(lanes::All, Val, Delta, SrcLane); 144 } 145 146 int64_t __kmpc_shuffle_int64(int64_t Val, int16_t Delta, int16_t Width) { 147 FunctionTracingRAII(); 148 uint32_t lo, hi; 149 utils::unpack(Val, lo, hi); 150 hi = impl::shuffleDown(lanes::All, hi, Delta, Width); 151 lo = impl::shuffleDown(lanes::All, lo, Delta, Width); 152 return utils::pack(lo, hi); 153 } 154 } 155 156 #pragma omp end declare target 157