1 //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// Implements a verifier for AMDGPU HSA metadata.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
16 #include "llvm/Support/AMDGPUMetadata.h"
17 
18 namespace llvm {
19 namespace AMDGPU {
20 namespace HSAMD {
21 namespace V3 {
22 
23 bool MetadataVerifier::verifyScalar(
24     msgpack::Node &Node, msgpack::ScalarNode::ScalarKind SKind,
25     function_ref<bool(msgpack::ScalarNode &)> verifyValue) {
26   auto ScalarPtr = dyn_cast<msgpack::ScalarNode>(&Node);
27   if (!ScalarPtr)
28     return false;
29   auto &Scalar = *ScalarPtr;
30   // Do not output extraneous tags for types we know from the spec.
31   Scalar.IgnoreTag = true;
32   if (Scalar.getScalarKind() != SKind) {
33     if (Strict)
34       return false;
35     // If we are not strict, we interpret string values as "implicitly typed"
36     // and attempt to coerce them to the expected type here.
37     if (Scalar.getScalarKind() != msgpack::ScalarNode::SK_String)
38       return false;
39     std::string StringValue = Scalar.getString();
40     Scalar.setScalarKind(SKind);
41     if (Scalar.inputYAML(StringValue) != StringRef())
42       return false;
43   }
44   if (verifyValue)
45     return verifyValue(Scalar);
46   return true;
47 }
48 
49 bool MetadataVerifier::verifyInteger(msgpack::Node &Node) {
50   if (!verifyScalar(Node, msgpack::ScalarNode::SK_UInt))
51     if (!verifyScalar(Node, msgpack::ScalarNode::SK_Int))
52       return false;
53   return true;
54 }
55 
56 bool MetadataVerifier::verifyArray(
57     msgpack::Node &Node, function_ref<bool(msgpack::Node &)> verifyNode,
58     Optional<size_t> Size) {
59   auto ArrayPtr = dyn_cast<msgpack::ArrayNode>(&Node);
60   if (!ArrayPtr)
61     return false;
62   auto &Array = *ArrayPtr;
63   if (Size && Array.size() != *Size)
64     return false;
65   for (auto &Item : Array)
66     if (!verifyNode(*Item.get()))
67       return false;
68 
69   return true;
70 }
71 
72 bool MetadataVerifier::verifyEntry(
73     msgpack::MapNode &MapNode, StringRef Key, bool Required,
74     function_ref<bool(msgpack::Node &)> verifyNode) {
75   auto Entry = MapNode.find(Key);
76   if (Entry == MapNode.end())
77     return !Required;
78   return verifyNode(*Entry->second.get());
79 }
80 
81 bool MetadataVerifier::verifyScalarEntry(
82     msgpack::MapNode &MapNode, StringRef Key, bool Required,
83     msgpack::ScalarNode::ScalarKind SKind,
84     function_ref<bool(msgpack::ScalarNode &)> verifyValue) {
85   return verifyEntry(MapNode, Key, Required, [=](msgpack::Node &Node) {
86     return verifyScalar(Node, SKind, verifyValue);
87   });
88 }
89 
90 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapNode &MapNode,
91                                           StringRef Key, bool Required) {
92   return verifyEntry(MapNode, Key, Required, [this](msgpack::Node &Node) {
93     return verifyInteger(Node);
94   });
95 }
96 
97 bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) {
98   auto ArgsMapPtr = dyn_cast<msgpack::MapNode>(&Node);
99   if (!ArgsMapPtr)
100     return false;
101   auto &ArgsMap = *ArgsMapPtr;
102 
103   if (!verifyScalarEntry(ArgsMap, ".name", false,
104                          msgpack::ScalarNode::SK_String))
105     return false;
106   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
107                          msgpack::ScalarNode::SK_String))
108     return false;
109   if (!verifyIntegerEntry(ArgsMap, ".size", true))
110     return false;
111   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
112     return false;
113   if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
114                          msgpack::ScalarNode::SK_String,
115                          [](msgpack::ScalarNode &SNode) {
116                            return StringSwitch<bool>(SNode.getString())
117                                .Case("by_value", true)
118                                .Case("global_buffer", true)
119                                .Case("dynamic_shared_pointer", true)
120                                .Case("sampler", true)
121                                .Case("image", true)
122                                .Case("pipe", true)
123                                .Case("queue", true)
124                                .Case("hidden_global_offset_x", true)
125                                .Case("hidden_global_offset_y", true)
126                                .Case("hidden_global_offset_z", true)
127                                .Case("hidden_none", true)
128                                .Case("hidden_printf_buffer", true)
129                                .Case("hidden_default_queue", true)
130                                .Case("hidden_completion_action", true)
131                                .Default(false);
132                          }))
133     return false;
134   if (!verifyScalarEntry(ArgsMap, ".value_type", true,
135                          msgpack::ScalarNode::SK_String,
136                          [](msgpack::ScalarNode &SNode) {
137                            return StringSwitch<bool>(SNode.getString())
138                                .Case("struct", true)
139                                .Case("i8", true)
140                                .Case("u8", true)
141                                .Case("i16", true)
142                                .Case("u16", true)
143                                .Case("f16", true)
144                                .Case("i32", true)
145                                .Case("u32", true)
146                                .Case("f32", true)
147                                .Case("i64", true)
148                                .Case("u64", true)
149                                .Case("f64", true)
150                                .Default(false);
151                          }))
152     return false;
153   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
154     return false;
155   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
156                          msgpack::ScalarNode::SK_String,
157                          [](msgpack::ScalarNode &SNode) {
158                            return StringSwitch<bool>(SNode.getString())
159                                .Case("private", true)
160                                .Case("global", true)
161                                .Case("constant", true)
162                                .Case("local", true)
163                                .Case("generic", true)
164                                .Case("region", true)
165                                .Default(false);
166                          }))
167     return false;
168   if (!verifyScalarEntry(ArgsMap, ".access", false,
169                          msgpack::ScalarNode::SK_String,
170                          [](msgpack::ScalarNode &SNode) {
171                            return StringSwitch<bool>(SNode.getString())
172                                .Case("read_only", true)
173                                .Case("write_only", true)
174                                .Case("read_write", true)
175                                .Default(false);
176                          }))
177     return false;
178   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
179                          msgpack::ScalarNode::SK_String,
180                          [](msgpack::ScalarNode &SNode) {
181                            return StringSwitch<bool>(SNode.getString())
182                                .Case("read_only", true)
183                                .Case("write_only", true)
184                                .Case("read_write", true)
185                                .Default(false);
186                          }))
187     return false;
188   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
189                          msgpack::ScalarNode::SK_Boolean))
190     return false;
191   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
192                          msgpack::ScalarNode::SK_Boolean))
193     return false;
194   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
195                          msgpack::ScalarNode::SK_Boolean))
196     return false;
197   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
198                          msgpack::ScalarNode::SK_Boolean))
199     return false;
200 
201   return true;
202 }
203 
204 bool MetadataVerifier::verifyKernel(msgpack::Node &Node) {
205   auto KernelMapPtr = dyn_cast<msgpack::MapNode>(&Node);
206   if (!KernelMapPtr)
207     return false;
208   auto &KernelMap = *KernelMapPtr;
209 
210   if (!verifyScalarEntry(KernelMap, ".name", true,
211                          msgpack::ScalarNode::SK_String))
212     return false;
213   if (!verifyScalarEntry(KernelMap, ".symbol", true,
214                          msgpack::ScalarNode::SK_String))
215     return false;
216   if (!verifyScalarEntry(KernelMap, ".language", false,
217                          msgpack::ScalarNode::SK_String,
218                          [](msgpack::ScalarNode &SNode) {
219                            return StringSwitch<bool>(SNode.getString())
220                                .Case("OpenCL C", true)
221                                .Case("OpenCL C++", true)
222                                .Case("HCC", true)
223                                .Case("HIP", true)
224                                .Case("OpenMP", true)
225                                .Case("Assembler", true)
226                                .Default(false);
227                          }))
228     return false;
229   if (!verifyEntry(
230           KernelMap, ".language_version", false, [this](msgpack::Node &Node) {
231             return verifyArray(
232                 Node,
233                 [this](msgpack::Node &Node) { return verifyInteger(Node); }, 2);
234           }))
235     return false;
236   if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::Node &Node) {
237         return verifyArray(Node, [this](msgpack::Node &Node) {
238           return verifyKernelArgs(Node);
239         });
240       }))
241     return false;
242   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
243                    [this](msgpack::Node &Node) {
244                      return verifyArray(Node,
245                                         [this](msgpack::Node &Node) {
246                                           return verifyInteger(Node);
247                                         },
248                                         3);
249                    }))
250     return false;
251   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
252                    [this](msgpack::Node &Node) {
253                      return verifyArray(Node,
254                                         [this](msgpack::Node &Node) {
255                                           return verifyInteger(Node);
256                                         },
257                                         3);
258                    }))
259     return false;
260   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
261                          msgpack::ScalarNode::SK_String))
262     return false;
263   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
264                          msgpack::ScalarNode::SK_String))
265     return false;
266   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
267     return false;
268   if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
269     return false;
270   if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
271     return false;
272   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
273     return false;
274   if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
275     return false;
276   if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
277     return false;
278   if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
279     return false;
280   if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
281     return false;
282   if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
283     return false;
284   if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
285     return false;
286 
287   return true;
288 }
289 
290 bool MetadataVerifier::verify(msgpack::Node &HSAMetadataRoot) {
291   auto RootMapPtr = dyn_cast<msgpack::MapNode>(&HSAMetadataRoot);
292   if (!RootMapPtr)
293     return false;
294   auto &RootMap = *RootMapPtr;
295 
296   if (!verifyEntry(
297           RootMap, "amdhsa.version", true, [this](msgpack::Node &Node) {
298             return verifyArray(
299                 Node,
300                 [this](msgpack::Node &Node) { return verifyInteger(Node); }, 2);
301           }))
302     return false;
303   if (!verifyEntry(
304           RootMap, "amdhsa.printf", false, [this](msgpack::Node &Node) {
305             return verifyArray(Node, [this](msgpack::Node &Node) {
306               return verifyScalar(Node, msgpack::ScalarNode::SK_String);
307             });
308           }))
309     return false;
310   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
311                    [this](msgpack::Node &Node) {
312                      return verifyArray(Node, [this](msgpack::Node &Node) {
313                        return verifyKernel(Node);
314                      });
315                    }))
316     return false;
317 
318   return true;
319 }
320 
321 } // end namespace V3
322 } // end namespace HSAMD
323 } // end namespace AMDGPU
324 } // end namespace llvm
325