1 /*===--------------------------------------------------------------------------
2  *              ATMI (Asynchronous Task and Memory Interface)
3  *
4  * This file is distributed under the MIT License. See LICENSE.txt for details.
5  *===------------------------------------------------------------------------*/
#include "data.h"
#include "atmi_runtime.h"
#include "internal.h"
#include "machine.h"
#include "rt.h"
#include <cassert>
#include <cstdlib>
#include <hsa.h>
#include <hsa_ext_amd.h>
#include <iostream>
#include <memory>
#include <stdio.h>
#include <string.h>
#include <thread>
#include <vector>
19 
20 using core::TaskImpl;
21 extern ATLMachine g_atl_machine;
22 
23 namespace core {
// Registry of tracked allocations: register_allocation() inserts a record
// here, and Runtime::Memfree / Runtime::Memcpy look pointers up in it.
ATLPointerTracker g_data_map;
// Forward declaration only; called by register_allocation() for allocations
// whose place has device type ATMI_DEVTYPE_CPU.
void allow_access_to_all_gpu_agents(void *ptr);
26 
27 const char *getPlaceStr(atmi_devtype_t type) {
28   switch (type) {
29   case ATMI_DEVTYPE_CPU:
30     return "CPU";
31   case ATMI_DEVTYPE_GPU:
32     return "GPU";
33   default:
34     return NULL;
35   }
36 }
37 
38 std::ostream &operator<<(std::ostream &os, const ATLData *ap) {
39   atmi_mem_place_t place = ap->place();
40   os << " devicePointer:" << ap->ptr() << " sizeBytes:" << ap->size()
41      << " place:(" << getPlaceStr(place.dev_type) << ", " << place.dev_id
42      << ", " << place.mem_id << ")";
43   return os;
44 }
45 
46 void ATLPointerTracker::insert(void *pointer, ATLData *p) {
47   std::lock_guard<std::mutex> l(mutex_);
48 
49   DEBUG_PRINT("insert: %p + %zu\n", pointer, p->size());
50   tracker_.insert(std::make_pair(ATLMemoryRange(pointer, p->size()), p));
51 }
52 
53 void ATLPointerTracker::remove(void *pointer) {
54   std::lock_guard<std::mutex> l(mutex_);
55   DEBUG_PRINT("remove: %p\n", pointer);
56   tracker_.erase(ATLMemoryRange(pointer, 1));
57 }
58 
59 ATLData *ATLPointerTracker::find(const void *pointer) {
60   std::lock_guard<std::mutex> l(mutex_);
61   ATLData *ret = NULL;
62   auto iter = tracker_.find(ATLMemoryRange(pointer, 1));
63   DEBUG_PRINT("find: %p\n", pointer);
64   if (iter != tracker_.end()) // found
65     ret = iter->second;
66   return ret;
67 }
68 
69 ATLProcessor &get_processor_by_mem_place(atmi_mem_place_t place) {
70   int dev_id = place.dev_id;
71   switch (place.dev_type) {
72   case ATMI_DEVTYPE_CPU:
73     return g_atl_machine.processors<ATLCPUProcessor>()[dev_id];
74   case ATMI_DEVTYPE_GPU:
75     return g_atl_machine.processors<ATLGPUProcessor>()[dev_id];
76   }
77 }
78 
79 static hsa_agent_t get_mem_agent(atmi_mem_place_t place) {
80   return get_processor_by_mem_place(place).agent();
81 }
82 
83 hsa_amd_memory_pool_t get_memory_pool_by_mem_place(atmi_mem_place_t place) {
84   ATLProcessor &proc = get_processor_by_mem_place(place);
85   return get_memory_pool(proc, place.mem_id);
86 }
87 
88 void register_allocation(void *ptr, size_t size, atmi_mem_place_t place) {
89   ATLData *data = new ATLData(ptr, size, place);
90   g_data_map.insert(ptr, data);
91   if (place.dev_type == ATMI_DEVTYPE_CPU)
92     allow_access_to_all_gpu_agents(ptr);
93   // TODO(ashwinma): what if one GPU wants to access another GPU?
94 }
95 
96 atmi_status_t Runtime::Malloc(void **ptr, size_t size, atmi_mem_place_t place) {
97   atmi_status_t ret = ATMI_STATUS_SUCCESS;
98   hsa_amd_memory_pool_t pool = get_memory_pool_by_mem_place(place);
99   hsa_status_t err = hsa_amd_memory_pool_allocate(pool, size, 0, ptr);
100   ErrorCheck(atmi_malloc, err);
101   DEBUG_PRINT("Malloced [%s %d] %p\n",
102               place.dev_type == ATMI_DEVTYPE_CPU ? "CPU" : "GPU", place.dev_id,
103               *ptr);
104   if (err != HSA_STATUS_SUCCESS)
105     ret = ATMI_STATUS_ERROR;
106 
107   register_allocation(*ptr, size, place);
108 
109   return ret;
110 }
111 
112 atmi_status_t Runtime::Memfree(void *ptr) {
113   atmi_status_t ret = ATMI_STATUS_SUCCESS;
114   hsa_status_t err;
115   ATLData *data = g_data_map.find(ptr);
116   if (!data)
117     ErrorCheck(Checking pointer info userData,
118                HSA_STATUS_ERROR_INVALID_ALLOCATION);
119 
120   g_data_map.remove(ptr);
121   delete data;
122 
123   err = hsa_amd_memory_pool_free(ptr);
124   ErrorCheck(atmi_free, err);
125   DEBUG_PRINT("Freed %p\n", ptr);
126 
127   if (err != HSA_STATUS_SUCCESS || !data)
128     ret = ATMI_STATUS_ERROR;
129   return ret;
130 }
131 
132 static hsa_status_t invoke_hsa_copy(hsa_signal_t sig, void *dest,
133                                     const void *src, size_t size,
134                                     hsa_agent_t agent) {
135   const hsa_signal_value_t init = 1;
136   const hsa_signal_value_t success = 0;
137   hsa_signal_store_screlease(sig, init);
138 
139   hsa_status_t err =
140       hsa_amd_memory_async_copy(dest, agent, src, agent, size, 0, NULL, sig);
141   if (err != HSA_STATUS_SUCCESS) {
142     return err;
143   }
144 
145   // async_copy reports success by decrementing and failure by setting to < 0
146   hsa_signal_value_t got = init;
147   while (got == init) {
148     got = hsa_signal_wait_scacquire(sig, HSA_SIGNAL_CONDITION_NE, init,
149                                     UINT64_MAX, ATMI_WAIT_STATE);
150   }
151 
152   if (got != success) {
153     return HSA_STATUS_ERROR;
154   }
155 
156   return err;
157 }
158 
159 struct atmiFreePtrDeletor {
160   void operator()(void *p) {
161     atmi_free(p); // ignore failure to free
162   }
163 };
164 
165 atmi_status_t Runtime::Memcpy(hsa_signal_t sig, void *dest, const void *src,
166                               size_t size) {
167   ATLData *src_data = g_data_map.find(src);
168   ATLData *dest_data = g_data_map.find(dest);
169   atmi_mem_place_t cpu = ATMI_MEM_PLACE_CPU_MEM(0, 0, 0);
170 
171   void *temp_host_ptr;
172   atmi_status_t ret = atmi_malloc(&temp_host_ptr, size, cpu);
173   if (ret != ATMI_STATUS_SUCCESS) {
174     return ret;
175   }
176   std::unique_ptr<void, atmiFreePtrDeletor> del(temp_host_ptr);
177 
178   if (src_data && !dest_data) {
179     // Copy from device to scratch to host
180     hsa_agent_t agent = get_mem_agent(src_data->place());
181     DEBUG_PRINT("Memcpy D2H device agent: %lu\n", agent.handle);
182 
183     if (invoke_hsa_copy(sig, temp_host_ptr, src, size, agent) !=
184         HSA_STATUS_SUCCESS) {
185       return ATMI_STATUS_ERROR;
186     }
187 
188     memcpy(dest, temp_host_ptr, size);
189 
190   } else if (!src_data && dest_data) {
191     // Copy from host to scratch to device
192     hsa_agent_t agent = get_mem_agent(dest_data->place());
193     DEBUG_PRINT("Memcpy H2D device agent: %lu\n", agent.handle);
194 
195     memcpy(temp_host_ptr, src, size);
196 
197     if (invoke_hsa_copy(sig, dest, temp_host_ptr, size, agent) !=
198         HSA_STATUS_SUCCESS) {
199       return ATMI_STATUS_ERROR;
200     }
201 
202   } else if (!src_data && !dest_data) {
203     // would be host to host, just call memcpy, or missing metadata
204     DEBUG_PRINT("atmi_memcpy invoked without metadata\n");
205     return ATMI_STATUS_ERROR;
206   } else {
207     DEBUG_PRINT("atmi_memcpy unimplemented device to device copy\n");
208     return ATMI_STATUS_ERROR;
209   }
210 
211   return ATMI_STATUS_SUCCESS;
212 }
213 
214 } // namespace core
215