//===- AMDGPUMetadataVerifier.cpp - HSA Metadata Verifier -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15 #include "llvm/ADT/StringSwitch.h"
16 #include "llvm/Support/AMDGPUMetadata.h"
17 
18 namespace llvm {
19 namespace AMDGPU {
20 namespace HSAMD {
21 namespace V3 {
22 
23 bool MetadataVerifier::verifyScalar(
24     msgpack::DocNode &Node, msgpack::Type SKind,
25     function_ref<bool(msgpack::DocNode &)> verifyValue) {
26   if (!Node.isScalar())
27     return false;
28   if (Node.getKind() != SKind) {
29     if (Strict)
30       return false;
31     // If we are not strict, we interpret string values as "implicitly typed"
32     // and attempt to coerce them to the expected type here.
33     if (Node.getKind() != msgpack::Type::String)
34       return false;
35     StringRef StringValue = Node.getString();
36     Node.fromString(StringValue);
37     if (Node.getKind() != SKind)
38       return false;
39   }
40   if (verifyValue)
41     return verifyValue(Node);
42   return true;
43 }
44 
45 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
46   if (!verifyScalar(Node, msgpack::Type::UInt))
47     if (!verifyScalar(Node, msgpack::Type::Int))
48       return false;
49   return true;
50 }
51 
52 bool MetadataVerifier::verifyArray(
53     msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
54     Optional<size_t> Size) {
55   if (!Node.isArray())
56     return false;
57   auto &Array = Node.getArray();
58   if (Size && Array.size() != *Size)
59     return false;
60   return llvm::all_of(Array, verifyNode);
61 }
62 
63 bool MetadataVerifier::verifyEntry(
64     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
65     function_ref<bool(msgpack::DocNode &)> verifyNode) {
66   auto Entry = MapNode.find(Key);
67   if (Entry == MapNode.end())
68     return !Required;
69   return verifyNode(Entry->second);
70 }
71 
72 bool MetadataVerifier::verifyScalarEntry(
73     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
74     msgpack::Type SKind,
75     function_ref<bool(msgpack::DocNode &)> verifyValue) {
76   return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
77     return verifyScalar(Node, SKind, verifyValue);
78   });
79 }
80 
81 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
82                                           StringRef Key, bool Required) {
83   return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
84     return verifyInteger(Node);
85   });
86 }
87 
88 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
89   if (!Node.isMap())
90     return false;
91   auto &ArgsMap = Node.getMap();
92 
93   if (!verifyScalarEntry(ArgsMap, ".name", false,
94                          msgpack::Type::String))
95     return false;
96   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
97                          msgpack::Type::String))
98     return false;
99   if (!verifyIntegerEntry(ArgsMap, ".size", true))
100     return false;
101   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
102     return false;
103   if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
104                          msgpack::Type::String,
105                          [](msgpack::DocNode &SNode) {
106                            return StringSwitch<bool>(SNode.getString())
107                                .Case("by_value", true)
108                                .Case("global_buffer", true)
109                                .Case("dynamic_shared_pointer", true)
110                                .Case("sampler", true)
111                                .Case("image", true)
112                                .Case("pipe", true)
113                                .Case("queue", true)
114                                .Case("hidden_global_offset_x", true)
115                                .Case("hidden_global_offset_y", true)
116                                .Case("hidden_global_offset_z", true)
117                                .Case("hidden_none", true)
118                                .Case("hidden_printf_buffer", true)
119                                .Case("hidden_hostcall_buffer", true)
120                                .Case("hidden_default_queue", true)
121                                .Case("hidden_completion_action", true)
122                                .Case("hidden_multigrid_sync_arg", true)
123                                .Default(false);
124                          }))
125     return false;
126   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
127     return false;
128   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
129                          msgpack::Type::String,
130                          [](msgpack::DocNode &SNode) {
131                            return StringSwitch<bool>(SNode.getString())
132                                .Case("private", true)
133                                .Case("global", true)
134                                .Case("constant", true)
135                                .Case("local", true)
136                                .Case("generic", true)
137                                .Case("region", true)
138                                .Default(false);
139                          }))
140     return false;
141   if (!verifyScalarEntry(ArgsMap, ".access", false,
142                          msgpack::Type::String,
143                          [](msgpack::DocNode &SNode) {
144                            return StringSwitch<bool>(SNode.getString())
145                                .Case("read_only", true)
146                                .Case("write_only", true)
147                                .Case("read_write", true)
148                                .Default(false);
149                          }))
150     return false;
151   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
152                          msgpack::Type::String,
153                          [](msgpack::DocNode &SNode) {
154                            return StringSwitch<bool>(SNode.getString())
155                                .Case("read_only", true)
156                                .Case("write_only", true)
157                                .Case("read_write", true)
158                                .Default(false);
159                          }))
160     return false;
161   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
162                          msgpack::Type::Boolean))
163     return false;
164   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
165                          msgpack::Type::Boolean))
166     return false;
167   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
168                          msgpack::Type::Boolean))
169     return false;
170   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
171                          msgpack::Type::Boolean))
172     return false;
173 
174   return true;
175 }
176 
177 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
178   if (!Node.isMap())
179     return false;
180   auto &KernelMap = Node.getMap();
181 
182   if (!verifyScalarEntry(KernelMap, ".name", true,
183                          msgpack::Type::String))
184     return false;
185   if (!verifyScalarEntry(KernelMap, ".symbol", true,
186                          msgpack::Type::String))
187     return false;
188   if (!verifyScalarEntry(KernelMap, ".language", false,
189                          msgpack::Type::String,
190                          [](msgpack::DocNode &SNode) {
191                            return StringSwitch<bool>(SNode.getString())
192                                .Case("OpenCL C", true)
193                                .Case("OpenCL C++", true)
194                                .Case("HCC", true)
195                                .Case("HIP", true)
196                                .Case("OpenMP", true)
197                                .Case("Assembler", true)
198                                .Default(false);
199                          }))
200     return false;
201   if (!verifyEntry(
202           KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
203             return verifyArray(
204                 Node,
205                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
206           }))
207     return false;
208   if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
209         return verifyArray(Node, [this](msgpack::DocNode &Node) {
210           return verifyKernelArgs(Node);
211         });
212       }))
213     return false;
214   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
215                    [this](msgpack::DocNode &Node) {
216                      return verifyArray(Node,
217                                         [this](msgpack::DocNode &Node) {
218                                           return verifyInteger(Node);
219                                         },
220                                         3);
221                    }))
222     return false;
223   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
224                    [this](msgpack::DocNode &Node) {
225                      return verifyArray(Node,
226                                         [this](msgpack::DocNode &Node) {
227                                           return verifyInteger(Node);
228                                         },
229                                         3);
230                    }))
231     return false;
232   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
233                          msgpack::Type::String))
234     return false;
235   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
236                          msgpack::Type::String))
237     return false;
238   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
239     return false;
240   if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
241     return false;
242   if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
243     return false;
244   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
245     return false;
246   if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
247     return false;
248   if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
249     return false;
250   if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
251     return false;
252   if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
253     return false;
254   if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
255     return false;
256   if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
257     return false;
258 
259   return true;
260 }
261 
262 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
263   if (!HSAMetadataRoot.isMap())
264     return false;
265   auto &RootMap = HSAMetadataRoot.getMap();
266 
267   if (!verifyEntry(
268           RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
269             return verifyArray(
270                 Node,
271                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
272           }))
273     return false;
274   if (!verifyEntry(
275           RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
276             return verifyArray(Node, [this](msgpack::DocNode &Node) {
277               return verifyScalar(Node, msgpack::Type::String);
278             });
279           }))
280     return false;
281   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
282                    [this](msgpack::DocNode &Node) {
283                      return verifyArray(Node, [this](msgpack::DocNode &Node) {
284                        return verifyKernel(Node);
285                      });
286                    }))
287     return false;
288 
289   return true;
290 }
291 
292 } // end namespace V3
293 } // end namespace HSAMD
294 } // end namespace AMDGPU
295 } // end namespace llvm
296