1 //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file 10 /// Implements a verifier for AMDGPU HSA metadata. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h" 15 #include "llvm/ADT/StringSwitch.h" 16 #include "llvm/Support/AMDGPUMetadata.h" 17 18 namespace llvm { 19 namespace AMDGPU { 20 namespace HSAMD { 21 namespace V3 { 22 23 bool MetadataVerifier::verifyScalar( 24 msgpack::DocNode &Node, msgpack::Type SKind, 25 function_ref<bool(msgpack::DocNode &)> verifyValue) { 26 if (!Node.isScalar()) 27 return false; 28 if (Node.getKind() != SKind) { 29 if (Strict) 30 return false; 31 // If we are not strict, we interpret string values as "implicitly typed" 32 // and attempt to coerce them to the expected type here. 33 if (Node.getKind() != msgpack::Type::String) 34 return false; 35 StringRef StringValue = Node.getString(); 36 Node.fromString(StringValue); 37 if (Node.getKind() != SKind) 38 return false; 39 } 40 if (verifyValue) 41 return verifyValue(Node); 42 return true; 43 } 44 45 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) { 46 if (!verifyScalar(Node, msgpack::Type::UInt)) 47 if (!verifyScalar(Node, msgpack::Type::Int)) 48 return false; 49 return true; 50 } 51 52 bool MetadataVerifier::verifyArray( 53 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode, 54 Optional<size_t> Size) { 55 if (!Node.isArray()) 56 return false; 57 auto &Array = Node.getArray(); 58 if (Size && Array.size() != *Size) 59 return false; 60 return llvm::all_of(Array, verifyNode); 61 } 62 63 bool MetadataVerifier::verifyEntry( 64 msgpack::MapDocNode &MapNode, StringRef Key, bool Required, 65 function_ref<bool(msgpack::DocNode &)> verifyNode) { 66 auto Entry = MapNode.find(Key); 67 if (Entry == MapNode.end()) 68 return !Required; 69 return verifyNode(Entry->second); 70 } 71 72 bool MetadataVerifier::verifyScalarEntry( 73 msgpack::MapDocNode &MapNode, StringRef Key, bool Required, 74 msgpack::Type SKind, 75 function_ref<bool(msgpack::DocNode &)> verifyValue) { 76 return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) { 77 return verifyScalar(Node, SKind, verifyValue); 78 }); 79 } 80 81 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode, 82 StringRef Key, bool Required) { 83 return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) { 84 return verifyInteger(Node); 85 }); 86 } 87 88 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) { 89 if (!Node.isMap()) 90 return false; 91 auto &ArgsMap = Node.getMap(); 92 93 if (!verifyScalarEntry(ArgsMap, ".name", false, 94 msgpack::Type::String)) 95 return false; 96 if (!verifyScalarEntry(ArgsMap, ".type_name", false, 97 msgpack::Type::String)) 98 return false; 99 if (!verifyIntegerEntry(ArgsMap, ".size", true)) 100 return false; 101 if (!verifyIntegerEntry(ArgsMap, ".offset", true)) 102 return false; 103 if (!verifyScalarEntry(ArgsMap, ".value_kind", true, 104 msgpack::Type::String, 105 [](msgpack::DocNode &SNode) { 106 return StringSwitch<bool>(SNode.getString()) 107 .Case("by_value", true) 108 .Case("global_buffer", true) 109 .Case("dynamic_shared_pointer", true) 110 .Case("sampler", true) 111 .Case("image", true) 112 .Case("pipe", true) 113 .Case("queue", true) 114 .Case("hidden_global_offset_x", true) 115 .Case("hidden_global_offset_y", true) 116 .Case("hidden_global_offset_z", true) 117 .Case("hidden_none", true) 118 .Case("hidden_printf_buffer", true) 119 .Case("hidden_hostcall_buffer", true) 120 .Case("hidden_default_queue", true) 121 .Case("hidden_completion_action", true) 122 .Case("hidden_multigrid_sync_arg", true) 123 .Default(false); 124 })) 125 return false; 126 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false)) 127 return false; 128 if (!verifyScalarEntry(ArgsMap, ".address_space", false, 129 msgpack::Type::String, 130 [](msgpack::DocNode &SNode) { 131 return StringSwitch<bool>(SNode.getString()) 132 .Case("private", true) 133 .Case("global", true) 134 .Case("constant", true) 135 .Case("local", true) 136 .Case("generic", true) 137 .Case("region", true) 138 .Default(false); 139 })) 140 return false; 141 if (!verifyScalarEntry(ArgsMap, ".access", false, 142 msgpack::Type::String, 143 [](msgpack::DocNode &SNode) { 144 return StringSwitch<bool>(SNode.getString()) 145 .Case("read_only", true) 146 .Case("write_only", true) 147 .Case("read_write", true) 148 .Default(false); 149 })) 150 return false; 151 if (!verifyScalarEntry(ArgsMap, ".actual_access", false, 152 msgpack::Type::String, 153 [](msgpack::DocNode &SNode) { 154 return StringSwitch<bool>(SNode.getString()) 155 .Case("read_only", true) 156 .Case("write_only", true) 157 .Case("read_write", true) 158 .Default(false); 159 })) 160 return false; 161 if (!verifyScalarEntry(ArgsMap, ".is_const", false, 162 msgpack::Type::Boolean)) 163 return false; 164 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false, 165 msgpack::Type::Boolean)) 166 return false; 167 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false, 168 msgpack::Type::Boolean)) 169 return false; 170 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false, 171 msgpack::Type::Boolean)) 172 return false; 173 174 return true; 175 } 176 177 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) { 178 if (!Node.isMap()) 179 return false; 180 auto &KernelMap = Node.getMap(); 181 182 if (!verifyScalarEntry(KernelMap, ".name", true, 183 msgpack::Type::String)) 184 return false; 185 if (!verifyScalarEntry(KernelMap, ".symbol", true, 186 msgpack::Type::String)) 187 return false; 188 if (!verifyScalarEntry(KernelMap, ".language", false, 189 msgpack::Type::String, 190 [](msgpack::DocNode &SNode) { 191 return StringSwitch<bool>(SNode.getString()) 192 .Case("OpenCL C", true) 193 .Case("OpenCL C++", true) 194 .Case("HCC", true) 195 .Case("HIP", true) 196 .Case("OpenMP", true) 197 .Case("Assembler", true) 198 .Default(false); 199 })) 200 return false; 201 if (!verifyEntry( 202 KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) { 203 return verifyArray( 204 Node, 205 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2); 206 })) 207 return false; 208 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) { 209 return verifyArray(Node, [this](msgpack::DocNode &Node) { 210 return verifyKernelArgs(Node); 211 }); 212 })) 213 return false; 214 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false, 215 [this](msgpack::DocNode &Node) { 216 return verifyArray(Node, 217 [this](msgpack::DocNode &Node) { 218 return verifyInteger(Node); 219 }, 220 3); 221 })) 222 return false; 223 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false, 224 [this](msgpack::DocNode &Node) { 225 return verifyArray(Node, 226 [this](msgpack::DocNode &Node) { 227 return verifyInteger(Node); 228 }, 229 3); 230 })) 231 return false; 232 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false, 233 msgpack::Type::String)) 234 return false; 235 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false, 236 msgpack::Type::String)) 237 return false; 238 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true)) 239 return false; 240 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true)) 241 return false; 242 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true)) 243 return false; 244 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true)) 245 return false; 246 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true)) 247 return false; 248 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true)) 249 return false; 250 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true)) 251 return false; 252 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true)) 253 return false; 254 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false)) 255 return false; 256 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false)) 257 return false; 258 259 return true; 260 } 261 262 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) { 263 if (!HSAMetadataRoot.isMap()) 264 return false; 265 auto &RootMap = HSAMetadataRoot.getMap(); 266 267 if (!verifyEntry( 268 RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) { 269 return verifyArray( 270 Node, 271 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2); 272 })) 273 return false; 274 if (!verifyEntry( 275 RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) { 276 return verifyArray(Node, [this](msgpack::DocNode &Node) { 277 return verifyScalar(Node, msgpack::Type::String); 278 }); 279 })) 280 return false; 281 if (!verifyEntry(RootMap, "amdhsa.kernels", true, 282 [this](msgpack::DocNode &Node) { 283 return verifyArray(Node, [this](msgpack::DocNode &Node) { 284 return verifyKernel(Node); 285 }); 286 })) 287 return false; 288 289 return true; 290 } 291 292 } // end namespace V3 293 } // end namespace HSAMD 294 } // end namespace AMDGPU 295 } // end namespace llvm 296