//===- AMDGPUMetadataVerifier.cpp - HSA Metadata Verifier -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/STLForwardCompat.h"
18 #include "llvm/ADT/StringSwitch.h"
19 #include "llvm/BinaryFormat/MsgPackDocument.h"
20
21 #include <map>
22 #include <utility>
23
24 namespace llvm {
25 namespace AMDGPU {
26 namespace HSAMD {
27 namespace V3 {
28
verifyScalar(msgpack::DocNode & Node,msgpack::Type SKind,function_ref<bool (msgpack::DocNode &)> verifyValue)29 bool MetadataVerifier::verifyScalar(
30 msgpack::DocNode &Node, msgpack::Type SKind,
31 function_ref<bool(msgpack::DocNode &)> verifyValue) {
32 if (!Node.isScalar())
33 return false;
34 if (Node.getKind() != SKind) {
35 if (Strict)
36 return false;
37 // If we are not strict, we interpret string values as "implicitly typed"
38 // and attempt to coerce them to the expected type here.
39 if (Node.getKind() != msgpack::Type::String)
40 return false;
41 StringRef StringValue = Node.getString();
42 Node.fromString(StringValue);
43 if (Node.getKind() != SKind)
44 return false;
45 }
46 if (verifyValue)
47 return verifyValue(Node);
48 return true;
49 }
50
verifyInteger(msgpack::DocNode & Node)51 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
52 if (!verifyScalar(Node, msgpack::Type::UInt))
53 if (!verifyScalar(Node, msgpack::Type::Int))
54 return false;
55 return true;
56 }
57
verifyArray(msgpack::DocNode & Node,function_ref<bool (msgpack::DocNode &)> verifyNode,Optional<size_t> Size)58 bool MetadataVerifier::verifyArray(
59 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
60 Optional<size_t> Size) {
61 if (!Node.isArray())
62 return false;
63 auto &Array = Node.getArray();
64 if (Size && Array.size() != *Size)
65 return false;
66 return llvm::all_of(Array, verifyNode);
67 }
68
verifyEntry(msgpack::MapDocNode & MapNode,StringRef Key,bool Required,function_ref<bool (msgpack::DocNode &)> verifyNode)69 bool MetadataVerifier::verifyEntry(
70 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
71 function_ref<bool(msgpack::DocNode &)> verifyNode) {
72 auto Entry = MapNode.find(Key);
73 if (Entry == MapNode.end())
74 return !Required;
75 return verifyNode(Entry->second);
76 }
77
verifyScalarEntry(msgpack::MapDocNode & MapNode,StringRef Key,bool Required,msgpack::Type SKind,function_ref<bool (msgpack::DocNode &)> verifyValue)78 bool MetadataVerifier::verifyScalarEntry(
79 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
80 msgpack::Type SKind,
81 function_ref<bool(msgpack::DocNode &)> verifyValue) {
82 return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
83 return verifyScalar(Node, SKind, verifyValue);
84 });
85 }
86
verifyIntegerEntry(msgpack::MapDocNode & MapNode,StringRef Key,bool Required)87 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
88 StringRef Key, bool Required) {
89 return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
90 return verifyInteger(Node);
91 });
92 }
93
verifyKernelArgs(msgpack::DocNode & Node)94 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
95 if (!Node.isMap())
96 return false;
97 auto &ArgsMap = Node.getMap();
98
99 if (!verifyScalarEntry(ArgsMap, ".name", false,
100 msgpack::Type::String))
101 return false;
102 if (!verifyScalarEntry(ArgsMap, ".type_name", false,
103 msgpack::Type::String))
104 return false;
105 if (!verifyIntegerEntry(ArgsMap, ".size", true))
106 return false;
107 if (!verifyIntegerEntry(ArgsMap, ".offset", true))
108 return false;
109 if (!verifyScalarEntry(ArgsMap, ".value_kind", true, msgpack::Type::String,
110 [](msgpack::DocNode &SNode) {
111 return StringSwitch<bool>(SNode.getString())
112 .Case("by_value", true)
113 .Case("global_buffer", true)
114 .Case("dynamic_shared_pointer", true)
115 .Case("sampler", true)
116 .Case("image", true)
117 .Case("pipe", true)
118 .Case("queue", true)
119 .Case("hidden_block_count_x", true)
120 .Case("hidden_block_count_y", true)
121 .Case("hidden_block_count_z", true)
122 .Case("hidden_group_size_x", true)
123 .Case("hidden_group_size_y", true)
124 .Case("hidden_group_size_z", true)
125 .Case("hidden_remainder_x", true)
126 .Case("hidden_remainder_y", true)
127 .Case("hidden_remainder_z", true)
128 .Case("hidden_global_offset_x", true)
129 .Case("hidden_global_offset_y", true)
130 .Case("hidden_global_offset_z", true)
131 .Case("hidden_grid_dims", true)
132 .Case("hidden_none", true)
133 .Case("hidden_printf_buffer", true)
134 .Case("hidden_hostcall_buffer", true)
135 .Case("hidden_heap_v1", true)
136 .Case("hidden_default_queue", true)
137 .Case("hidden_completion_action", true)
138 .Case("hidden_multigrid_sync_arg", true)
139 .Case("hidden_private_base", true)
140 .Case("hidden_shared_base", true)
141 .Case("hidden_queue_ptr", true)
142 .Default(false);
143 }))
144 return false;
145 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
146 return false;
147 if (!verifyScalarEntry(ArgsMap, ".address_space", false,
148 msgpack::Type::String,
149 [](msgpack::DocNode &SNode) {
150 return StringSwitch<bool>(SNode.getString())
151 .Case("private", true)
152 .Case("global", true)
153 .Case("constant", true)
154 .Case("local", true)
155 .Case("generic", true)
156 .Case("region", true)
157 .Default(false);
158 }))
159 return false;
160 if (!verifyScalarEntry(ArgsMap, ".access", false,
161 msgpack::Type::String,
162 [](msgpack::DocNode &SNode) {
163 return StringSwitch<bool>(SNode.getString())
164 .Case("read_only", true)
165 .Case("write_only", true)
166 .Case("read_write", true)
167 .Default(false);
168 }))
169 return false;
170 if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
171 msgpack::Type::String,
172 [](msgpack::DocNode &SNode) {
173 return StringSwitch<bool>(SNode.getString())
174 .Case("read_only", true)
175 .Case("write_only", true)
176 .Case("read_write", true)
177 .Default(false);
178 }))
179 return false;
180 if (!verifyScalarEntry(ArgsMap, ".is_const", false,
181 msgpack::Type::Boolean))
182 return false;
183 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
184 msgpack::Type::Boolean))
185 return false;
186 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
187 msgpack::Type::Boolean))
188 return false;
189 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
190 msgpack::Type::Boolean))
191 return false;
192
193 return true;
194 }
195
verifyKernel(msgpack::DocNode & Node)196 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
197 if (!Node.isMap())
198 return false;
199 auto &KernelMap = Node.getMap();
200
201 if (!verifyScalarEntry(KernelMap, ".name", true,
202 msgpack::Type::String))
203 return false;
204 if (!verifyScalarEntry(KernelMap, ".symbol", true,
205 msgpack::Type::String))
206 return false;
207 if (!verifyScalarEntry(KernelMap, ".language", false,
208 msgpack::Type::String,
209 [](msgpack::DocNode &SNode) {
210 return StringSwitch<bool>(SNode.getString())
211 .Case("OpenCL C", true)
212 .Case("OpenCL C++", true)
213 .Case("HCC", true)
214 .Case("HIP", true)
215 .Case("OpenMP", true)
216 .Case("Assembler", true)
217 .Default(false);
218 }))
219 return false;
220 if (!verifyEntry(
221 KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
222 return verifyArray(
223 Node,
224 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
225 }))
226 return false;
227 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
228 return verifyArray(Node, [this](msgpack::DocNode &Node) {
229 return verifyKernelArgs(Node);
230 });
231 }))
232 return false;
233 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
234 [this](msgpack::DocNode &Node) {
235 return verifyArray(Node,
236 [this](msgpack::DocNode &Node) {
237 return verifyInteger(Node);
238 },
239 3);
240 }))
241 return false;
242 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
243 [this](msgpack::DocNode &Node) {
244 return verifyArray(Node,
245 [this](msgpack::DocNode &Node) {
246 return verifyInteger(Node);
247 },
248 3);
249 }))
250 return false;
251 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
252 msgpack::Type::String))
253 return false;
254 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
255 msgpack::Type::String))
256 return false;
257 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
258 return false;
259 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
260 return false;
261 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
262 return false;
263 if (!verifyScalarEntry(KernelMap, ".uses_dynamic_stack", false,
264 msgpack::Type::Boolean))
265 return false;
266 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
267 return false;
268 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
269 return false;
270 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
271 return false;
272 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
273 return false;
274 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
275 return false;
276 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
277 return false;
278 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
279 return false;
280
281 return true;
282 }
283
verify(msgpack::DocNode & HSAMetadataRoot)284 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
285 if (!HSAMetadataRoot.isMap())
286 return false;
287 auto &RootMap = HSAMetadataRoot.getMap();
288
289 if (!verifyEntry(
290 RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
291 return verifyArray(
292 Node,
293 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
294 }))
295 return false;
296 if (!verifyEntry(
297 RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
298 return verifyArray(Node, [this](msgpack::DocNode &Node) {
299 return verifyScalar(Node, msgpack::Type::String);
300 });
301 }))
302 return false;
303 if (!verifyEntry(RootMap, "amdhsa.kernels", true,
304 [this](msgpack::DocNode &Node) {
305 return verifyArray(Node, [this](msgpack::DocNode &Node) {
306 return verifyKernel(Node);
307 });
308 }))
309 return false;
310
311 return true;
312 }
313
314 } // end namespace V3
315 } // end namespace HSAMD
316 } // end namespace AMDGPU
317 } // end namespace llvm
318