1 //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15 #include "llvm/ADT/StringSwitch.h"
16 #include "llvm/Support/AMDGPUMetadata.h"
17 
18 namespace llvm {
19 namespace AMDGPU {
20 namespace HSAMD {
21 namespace V3 {
22 
23 bool MetadataVerifier::verifyScalar(
24     msgpack::DocNode &Node, msgpack::Type SKind,
25     function_ref<bool(msgpack::DocNode &)> verifyValue) {
26   if (!Node.isScalar())
27     return false;
28   if (Node.getKind() != SKind) {
29     if (Strict)
30       return false;
31     // If we are not strict, we interpret string values as "implicitly typed"
32     // and attempt to coerce them to the expected type here.
33     if (Node.getKind() != msgpack::Type::String)
34       return false;
35     StringRef StringValue = Node.getString();
36     Node.fromString(StringValue);
37     if (Node.getKind() != SKind)
38       return false;
39   }
40   if (verifyValue)
41     return verifyValue(Node);
42   return true;
43 }
44 
45 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
46   if (!verifyScalar(Node, msgpack::Type::UInt))
47     if (!verifyScalar(Node, msgpack::Type::Int))
48       return false;
49   return true;
50 }
51 
52 bool MetadataVerifier::verifyArray(
53     msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
54     Optional<size_t> Size) {
55   if (!Node.isArray())
56     return false;
57   auto &Array = Node.getArray();
58   if (Size && Array.size() != *Size)
59     return false;
60   for (auto &Item : Array)
61     if (!verifyNode(Item))
62       return false;
63 
64   return true;
65 }
66 
67 bool MetadataVerifier::verifyEntry(
68     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
69     function_ref<bool(msgpack::DocNode &)> verifyNode) {
70   auto Entry = MapNode.find(Key);
71   if (Entry == MapNode.end())
72     return !Required;
73   return verifyNode(Entry->second);
74 }
75 
76 bool MetadataVerifier::verifyScalarEntry(
77     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
78     msgpack::Type SKind,
79     function_ref<bool(msgpack::DocNode &)> verifyValue) {
80   return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
81     return verifyScalar(Node, SKind, verifyValue);
82   });
83 }
84 
85 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
86                                           StringRef Key, bool Required) {
87   return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
88     return verifyInteger(Node);
89   });
90 }
91 
92 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
93   if (!Node.isMap())
94     return false;
95   auto &ArgsMap = Node.getMap();
96 
97   if (!verifyScalarEntry(ArgsMap, ".name", false,
98                          msgpack::Type::String))
99     return false;
100   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
101                          msgpack::Type::String))
102     return false;
103   if (!verifyIntegerEntry(ArgsMap, ".size", true))
104     return false;
105   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
106     return false;
107   if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
108                          msgpack::Type::String,
109                          [](msgpack::DocNode &SNode) {
110                            return StringSwitch<bool>(SNode.getString())
111                                .Case("by_value", true)
112                                .Case("global_buffer", true)
113                                .Case("dynamic_shared_pointer", true)
114                                .Case("sampler", true)
115                                .Case("image", true)
116                                .Case("pipe", true)
117                                .Case("queue", true)
118                                .Case("hidden_global_offset_x", true)
119                                .Case("hidden_global_offset_y", true)
120                                .Case("hidden_global_offset_z", true)
121                                .Case("hidden_none", true)
122                                .Case("hidden_printf_buffer", true)
123                                .Case("hidden_hostcall_buffer", true)
124                                .Case("hidden_default_queue", true)
125                                .Case("hidden_completion_action", true)
126                                .Case("hidden_multigrid_sync_arg", true)
127                                .Default(false);
128                          }))
129     return false;
130   if (!verifyScalarEntry(ArgsMap, ".value_type", true,
131                          msgpack::Type::String,
132                          [](msgpack::DocNode &SNode) {
133                            return StringSwitch<bool>(SNode.getString())
134                                .Case("struct", true)
135                                .Case("i8", true)
136                                .Case("u8", true)
137                                .Case("i16", true)
138                                .Case("u16", true)
139                                .Case("f16", true)
140                                .Case("i32", true)
141                                .Case("u32", true)
142                                .Case("f32", true)
143                                .Case("i64", true)
144                                .Case("u64", true)
145                                .Case("f64", true)
146                                .Default(false);
147                          }))
148     return false;
149   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
150     return false;
151   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
152                          msgpack::Type::String,
153                          [](msgpack::DocNode &SNode) {
154                            return StringSwitch<bool>(SNode.getString())
155                                .Case("private", true)
156                                .Case("global", true)
157                                .Case("constant", true)
158                                .Case("local", true)
159                                .Case("generic", true)
160                                .Case("region", true)
161                                .Default(false);
162                          }))
163     return false;
164   if (!verifyScalarEntry(ArgsMap, ".access", false,
165                          msgpack::Type::String,
166                          [](msgpack::DocNode &SNode) {
167                            return StringSwitch<bool>(SNode.getString())
168                                .Case("read_only", true)
169                                .Case("write_only", true)
170                                .Case("read_write", true)
171                                .Default(false);
172                          }))
173     return false;
174   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
175                          msgpack::Type::String,
176                          [](msgpack::DocNode &SNode) {
177                            return StringSwitch<bool>(SNode.getString())
178                                .Case("read_only", true)
179                                .Case("write_only", true)
180                                .Case("read_write", true)
181                                .Default(false);
182                          }))
183     return false;
184   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
185                          msgpack::Type::Boolean))
186     return false;
187   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
188                          msgpack::Type::Boolean))
189     return false;
190   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
191                          msgpack::Type::Boolean))
192     return false;
193   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
194                          msgpack::Type::Boolean))
195     return false;
196 
197   return true;
198 }
199 
200 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
201   if (!Node.isMap())
202     return false;
203   auto &KernelMap = Node.getMap();
204 
205   if (!verifyScalarEntry(KernelMap, ".name", true,
206                          msgpack::Type::String))
207     return false;
208   if (!verifyScalarEntry(KernelMap, ".symbol", true,
209                          msgpack::Type::String))
210     return false;
211   if (!verifyScalarEntry(KernelMap, ".language", false,
212                          msgpack::Type::String,
213                          [](msgpack::DocNode &SNode) {
214                            return StringSwitch<bool>(SNode.getString())
215                                .Case("OpenCL C", true)
216                                .Case("OpenCL C++", true)
217                                .Case("HCC", true)
218                                .Case("HIP", true)
219                                .Case("OpenMP", true)
220                                .Case("Assembler", true)
221                                .Default(false);
222                          }))
223     return false;
224   if (!verifyEntry(
225           KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
226             return verifyArray(
227                 Node,
228                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
229           }))
230     return false;
231   if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
232         return verifyArray(Node, [this](msgpack::DocNode &Node) {
233           return verifyKernelArgs(Node);
234         });
235       }))
236     return false;
237   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
238                    [this](msgpack::DocNode &Node) {
239                      return verifyArray(Node,
240                                         [this](msgpack::DocNode &Node) {
241                                           return verifyInteger(Node);
242                                         },
243                                         3);
244                    }))
245     return false;
246   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
247                    [this](msgpack::DocNode &Node) {
248                      return verifyArray(Node,
249                                         [this](msgpack::DocNode &Node) {
250                                           return verifyInteger(Node);
251                                         },
252                                         3);
253                    }))
254     return false;
255   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
256                          msgpack::Type::String))
257     return false;
258   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
259                          msgpack::Type::String))
260     return false;
261   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
262     return false;
263   if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
264     return false;
265   if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
266     return false;
267   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
268     return false;
269   if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
270     return false;
271   if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
272     return false;
273   if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
274     return false;
275   if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
276     return false;
277   if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
278     return false;
279   if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
280     return false;
281 
282   return true;
283 }
284 
285 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
286   if (!HSAMetadataRoot.isMap())
287     return false;
288   auto &RootMap = HSAMetadataRoot.getMap();
289 
290   if (!verifyEntry(
291           RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
292             return verifyArray(
293                 Node,
294                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
295           }))
296     return false;
297   if (!verifyEntry(
298           RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
299             return verifyArray(Node, [this](msgpack::DocNode &Node) {
300               return verifyScalar(Node, msgpack::Type::String);
301             });
302           }))
303     return false;
304   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
305                    [this](msgpack::DocNode &Node) {
306                      return verifyArray(Node, [this](msgpack::DocNode &Node) {
307                        return verifyKernel(Node);
308                      });
309                    }))
310     return false;
311 
312   return true;
313 }
314 
315 } // end namespace V3
316 } // end namespace HSAMD
317 } // end namespace AMDGPU
318 } // end namespace llvm
319