1 /*===-------------------------------------------------------------------------- 2 * ATMI (Asynchronous Task and Memory Interface) 3 * 4 * This file is distributed under the MIT License. See LICENSE.txt for details. 5 *===------------------------------------------------------------------------*/ 6 #include "data.h" 7 #include "atmi_runtime.h" 8 #include "internal.h" 9 #include "machine.h" 10 #include "rt.h" 11 #include <cassert> 12 #include <hsa.h> 13 #include <hsa_ext_amd.h> 14 #include <iostream> 15 #include <stdio.h> 16 #include <string.h> 17 #include <thread> 18 #include <vector> 19 20 using core::TaskImpl; 21 extern ATLMachine g_atl_machine; 22 extern hsa_signal_t IdentityCopySignal; 23 24 namespace core { 25 ATLPointerTracker g_data_map; // Track all am pointer allocations. 26 void allow_access_to_all_gpu_agents(void *ptr); 27 28 const char *getPlaceStr(atmi_devtype_t type) { 29 switch (type) { 30 case ATMI_DEVTYPE_CPU: 31 return "CPU"; 32 case ATMI_DEVTYPE_GPU: 33 return "GPU"; 34 default: 35 return NULL; 36 } 37 } 38 39 std::ostream &operator<<(std::ostream &os, const ATLData *ap) { 40 atmi_mem_place_t place = ap->place(); 41 os << " devicePointer:" << ap->ptr() << " sizeBytes:" << ap->size() 42 << " place:(" << getPlaceStr(place.dev_type) << ", " << place.dev_id 43 << ", " << place.mem_id << ")"; 44 return os; 45 } 46 47 void ATLPointerTracker::insert(void *pointer, ATLData *p) { 48 std::lock_guard<std::mutex> l(mutex_); 49 50 DEBUG_PRINT("insert: %p + %zu\n", pointer, p->size()); 51 tracker_.insert(std::make_pair(ATLMemoryRange(pointer, p->size()), p)); 52 } 53 54 void ATLPointerTracker::remove(void *pointer) { 55 std::lock_guard<std::mutex> l(mutex_); 56 DEBUG_PRINT("remove: %p\n", pointer); 57 tracker_.erase(ATLMemoryRange(pointer, 1)); 58 } 59 60 ATLData *ATLPointerTracker::find(const void *pointer) { 61 std::lock_guard<std::mutex> l(mutex_); 62 ATLData *ret = NULL; 63 auto iter = tracker_.find(ATLMemoryRange(pointer, 1)); 64 DEBUG_PRINT("find: %p\n", pointer); 65 if (iter != tracker_.end()) // found 66 ret = iter->second; 67 return ret; 68 } 69 70 ATLProcessor &get_processor_by_mem_place(atmi_mem_place_t place) { 71 int dev_id = place.dev_id; 72 switch (place.dev_type) { 73 case ATMI_DEVTYPE_CPU: 74 return g_atl_machine.processors<ATLCPUProcessor>()[dev_id]; 75 case ATMI_DEVTYPE_GPU: 76 return g_atl_machine.processors<ATLGPUProcessor>()[dev_id]; 77 } 78 } 79 80 static hsa_agent_t get_mem_agent(atmi_mem_place_t place) { 81 return get_processor_by_mem_place(place).agent(); 82 } 83 84 hsa_amd_memory_pool_t get_memory_pool_by_mem_place(atmi_mem_place_t place) { 85 ATLProcessor &proc = get_processor_by_mem_place(place); 86 return get_memory_pool(proc, place.mem_id); 87 } 88 89 void register_allocation(void *ptr, size_t size, atmi_mem_place_t place) { 90 ATLData *data = new ATLData(ptr, size, place); 91 g_data_map.insert(ptr, data); 92 if (place.dev_type == ATMI_DEVTYPE_CPU) 93 allow_access_to_all_gpu_agents(ptr); 94 // TODO(ashwinma): what if one GPU wants to access another GPU? 95 } 96 97 atmi_status_t Runtime::Malloc(void **ptr, size_t size, atmi_mem_place_t place) { 98 atmi_status_t ret = ATMI_STATUS_SUCCESS; 99 hsa_amd_memory_pool_t pool = get_memory_pool_by_mem_place(place); 100 hsa_status_t err = hsa_amd_memory_pool_allocate(pool, size, 0, ptr); 101 ErrorCheck(atmi_malloc, err); 102 DEBUG_PRINT("Malloced [%s %d] %p\n", 103 place.dev_type == ATMI_DEVTYPE_CPU ? "CPU" : "GPU", place.dev_id, 104 *ptr); 105 if (err != HSA_STATUS_SUCCESS) 106 ret = ATMI_STATUS_ERROR; 107 108 register_allocation(*ptr, size, place); 109 110 return ret; 111 } 112 113 atmi_status_t Runtime::Memfree(void *ptr) { 114 atmi_status_t ret = ATMI_STATUS_SUCCESS; 115 hsa_status_t err; 116 ATLData *data = g_data_map.find(ptr); 117 if (!data) 118 ErrorCheck(Checking pointer info userData, 119 HSA_STATUS_ERROR_INVALID_ALLOCATION); 120 121 g_data_map.remove(ptr); 122 delete data; 123 124 err = hsa_amd_memory_pool_free(ptr); 125 ErrorCheck(atmi_free, err); 126 DEBUG_PRINT("Freed %p\n", ptr); 127 128 if (err != HSA_STATUS_SUCCESS || !data) 129 ret = ATMI_STATUS_ERROR; 130 return ret; 131 } 132 133 static hsa_status_t invoke_hsa_copy(void *dest, const void *src, size_t size, 134 hsa_agent_t agent) { 135 // TODO: Use thread safe signal 136 hsa_signal_store_screlease(IdentityCopySignal, 1); 137 138 hsa_status_t err = hsa_amd_memory_async_copy(dest, agent, src, agent, size, 0, 139 NULL, IdentityCopySignal); 140 ErrorCheck(Copy async between memory pools, err); 141 142 // TODO: async reports errors in the signal, use NE 1 143 hsa_signal_wait_scacquire(IdentityCopySignal, HSA_SIGNAL_CONDITION_EQ, 0, 144 UINT64_MAX, ATMI_WAIT_STATE); 145 146 return err; 147 } 148 149 atmi_status_t Runtime::Memcpy(void *dest, const void *src, size_t size) { 150 atmi_status_t ret; 151 hsa_status_t err; 152 ATLData *src_data = g_data_map.find(src); 153 ATLData *dest_data = g_data_map.find(dest); 154 atmi_mem_place_t cpu = ATMI_MEM_PLACE_CPU_MEM(0, 0, 0); 155 void *temp_host_ptr; 156 157 if (src_data && !dest_data) { 158 // Copy from device to scratch to host 159 hsa_agent_t agent = get_mem_agent(src_data->place()); 160 DEBUG_PRINT("Memcpy D2H device agent: %lu\n", agent.handle); 161 ret = atmi_malloc(&temp_host_ptr, size, cpu); 162 if (ret != ATMI_STATUS_SUCCESS) { 163 return ret; 164 } 165 166 err = invoke_hsa_copy(temp_host_ptr, src, size, agent); 167 if (err != HSA_STATUS_SUCCESS) { 168 return ATMI_STATUS_ERROR; 169 } 170 171 memcpy(dest, temp_host_ptr, size); 172 173 } else if (!src_data && dest_data) { 174 // Copy from host to scratch to device 175 hsa_agent_t agent = get_mem_agent(dest_data->place()); 176 DEBUG_PRINT("Memcpy H2D device agent: %lu\n", agent.handle); 177 ret = atmi_malloc(&temp_host_ptr, size, cpu); 178 if (ret != ATMI_STATUS_SUCCESS) { 179 return ret; 180 } 181 182 memcpy(temp_host_ptr, src, size); 183 184 DEBUG_PRINT("Memcpy device agent: %lu\n", agent.handle); 185 err = invoke_hsa_copy(dest, temp_host_ptr, size, agent); 186 187 } else if (!src_data && !dest_data) { 188 DEBUG_PRINT("atmi_memcpy invoked without metadata\n"); 189 // would be host to host, just call memcpy, or missing metadata 190 return ATMI_STATUS_ERROR; 191 } else { 192 DEBUG_PRINT("atmi_memcpy unimplemented device to device copy\n"); 193 return ATMI_STATUS_ERROR; 194 } 195 196 ret = atmi_free(temp_host_ptr); 197 198 if (err != HSA_STATUS_SUCCESS || ret != ATMI_STATUS_SUCCESS) 199 ret = ATMI_STATUS_ERROR; 200 return ret; 201 } 202 203 } // namespace core 204