1#!/usr/bin/env python3 2# -*- coding: utf-8 -*- 3 4# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5# See https://llvm.org/LICENSE.txt for license information. 6# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 8# Script for updating SPIR-V dialect by scraping information from SPIR-V 9# HTML and JSON specs from the Internet. 10# 11# For example, to define the enum attribute for SPIR-V memory model: 12# 13# ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \ 14# --new-enum MemoryModel 15# 16# The 'operand_kinds' dict of spirv.core.grammar.json contains all supported 17# SPIR-V enum classes. 18 19import itertools 20import re 21import requests 22import textwrap 23import yaml 24 25SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html' 26SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json' 27 28SPIRV_CL_EXT_HTML_SPEC_URL = 'https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html' 29SPIRV_CL_EXT_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json' 30 31AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n' 32AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!' 33AUTOGEN_OPCODE_SECTION_MARKER = ( 34 'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!') 35 36def get_spirv_doc_from_html_spec(url, settings): 37 """Extracts instruction documentation from SPIR-V HTML spec. 38 39 Returns: 40 - A dict mapping from instruction opcode to documentation. 41 """ 42 if url is None: 43 url = SPIRV_HTML_SPEC_URL 44 45 response = requests.get(url) 46 spec = response.content 47 48 from bs4 import BeautifulSoup 49 spirv = BeautifulSoup(spec, 'html.parser') 50 51 doc = {} 52 53 if settings.gen_cl_ops: 54 section_anchor = spirv.find('h2', {'id': '_binary_form'}) 55 for section in section_anchor.parent.find_all('div', {'class': 'sect2'}): 56 for table in section.find_all('table'): 57 inst_html = table.tbody.tr.td 58 opname = inst_html.a['id'] 59 # Ignore the first line, which is just the opname. 60 doc[opname] = inst_html.text.split('\n', 1)[1].strip() 61 else: 62 section_anchor = spirv.find('h3', {'id': '_instructions_3'}) 63 for section in section_anchor.parent.find_all('div', {'class': 'sect3'}): 64 for table in section.find_all('table'): 65 inst_html = table.tbody.tr.td.p 66 opname = inst_html.a['id'] 67 # Ignore the first line, which is just the opname. 68 doc[opname] = inst_html.text.split('\n', 1)[1].strip() 69 70 return doc 71 72 73def get_spirv_grammar_from_json_spec(url): 74 """Extracts operand kind and instruction grammar from SPIR-V JSON spec. 75 76 Returns: 77 - A list containing all operand kinds' grammar 78 - A list containing all instructions' grammar 79 """ 80 response = requests.get(SPIRV_JSON_SPEC_URL) 81 spec = response.content 82 83 import json 84 spirv = json.loads(spec) 85 86 if url is None: 87 return spirv['operand_kinds'], spirv['instructions'] 88 89 response_ext = requests.get(url) 90 spec_ext = response_ext.content 91 spirv_ext = json.loads(spec_ext) 92 93 return spirv['operand_kinds'], spirv_ext['instructions'] 94 95 96def split_list_into_sublists(items): 97 """Split the list of items into multiple sublists. 98 99 This is to make sure the string composed from each sublist won't exceed 100 80 characters. 101 102 Arguments: 103 - items: a list of strings 104 """ 105 chuncks = [] 106 chunk = [] 107 chunk_len = 0 108 109 for item in items: 110 chunk_len += len(item) + 2 111 if chunk_len > 80: 112 chuncks.append(chunk) 113 chunk = [] 114 chunk_len = len(item) + 2 115 chunk.append(item) 116 117 if len(chunk) != 0: 118 chuncks.append(chunk) 119 120 return chuncks 121 122 123def uniquify_enum_cases(lst): 124 """Prunes duplicate enum cases from the list. 125 126 Arguments: 127 - lst: List whose elements are to be uniqued. Assumes each element is a 128 (symbol, value) pair and elements already sorted according to value. 129 130 Returns: 131 - A list with all duplicates removed. The elements are sorted according to 132 value and, for each value, uniqued according to symbol. 133 original list, 134 - A map from deduplicated cases to the uniqued case. 135 """ 136 cases = lst 137 uniqued_cases = [] 138 duplicated_cases = {} 139 140 # First sort according to the value 141 cases.sort(key=lambda x: x[1]) 142 143 # Then group them according to the value 144 for _, groups in itertools.groupby(cases, key=lambda x: x[1]): 145 # For each value, sort according to the enumerant symbol. 146 sorted_group = sorted(groups, key=lambda x: x[0]) 147 # Keep the "smallest" case, which is typically the symbol without extension 148 # suffix. But we have special cases that we want to fix. 149 case = sorted_group[0] 150 for i in range(1, len(sorted_group)): 151 duplicated_cases[sorted_group[i][0]] = case[0] 152 if case[0] == 'HlslSemanticGOOGLE': 153 assert len(sorted_group) == 2, 'unexpected new variant for HlslSemantic' 154 case = sorted_group[1] 155 duplicated_cases[sorted_group[0][0]] = case[0] 156 uniqued_cases.append(case) 157 158 return uniqued_cases, duplicated_cases 159 160 161def toposort(dag, sort_fn): 162 """Topologically sorts the given dag. 163 164 Arguments: 165 - dag: a dict mapping from a node to its incoming nodes. 166 - sort_fn: a function for sorting nodes in the same batch. 167 168 Returns: 169 A list containing topologically sorted nodes. 170 """ 171 172 # Returns the next batch of nodes without incoming edges 173 def get_next_batch(dag): 174 while True: 175 no_prev_nodes = set(node for node, prev in dag.items() if not prev) 176 if not no_prev_nodes: 177 break 178 yield sorted(no_prev_nodes, key=sort_fn) 179 dag = { 180 node: (prev - no_prev_nodes) 181 for node, prev in dag.items() 182 if node not in no_prev_nodes 183 } 184 assert not dag, 'found cyclic dependency' 185 186 sorted_nodes = [] 187 for batch in get_next_batch(dag): 188 sorted_nodes.extend(batch) 189 190 return sorted_nodes 191 192 193def toposort_capabilities(all_cases, capability_mapping): 194 """Returns topologically sorted capability (symbol, value) pairs. 195 196 Arguments: 197 - all_cases: all capability cases (containing symbol, value, and implied 198 capabilities). 199 - capability_mapping: mapping from duplicated capability symbols to the 200 canonicalized symbol chosen for SPIRVBase.td. 201 202 Returns: 203 A list containing topologically sorted capability (symbol, value) pairs. 204 """ 205 dag = {} 206 name_to_value = {} 207 for case in all_cases: 208 # Get the current capability. 209 cur = case['enumerant'] 210 name_to_value[cur] = case['value'] 211 # Ignore duplicated symbols. 212 if cur in capability_mapping: 213 continue 214 215 # Get capabilities implied by the current capability. 216 prev = case.get('capabilities', []) 217 uniqued_prev = set([capability_mapping.get(c, c) for c in prev]) 218 dag[cur] = uniqued_prev 219 220 sorted_caps = toposort(dag, lambda x: name_to_value[x]) 221 # Attach the capability's value as the second component of the pair. 222 return [(c, name_to_value[c]) for c in sorted_caps] 223 224 225def get_capability_mapping(operand_kinds): 226 """Returns the capability mapping from duplicated cases to canonicalized ones. 227 228 Arguments: 229 - operand_kinds: all operand kinds' grammar spec 230 231 Returns: 232 - A map mapping from duplicated capability symbols to the canonicalized 233 symbol chosen for SPIRVBase.td. 234 """ 235 # Find the operand kind for capability 236 cap_kind = {} 237 for kind in operand_kinds: 238 if kind['kind'] == 'Capability': 239 cap_kind = kind 240 241 kind_cases = [ 242 (case['enumerant'], case['value']) for case in cap_kind['enumerants'] 243 ] 244 _, capability_mapping = uniquify_enum_cases(kind_cases) 245 246 return capability_mapping 247 248 249def get_availability_spec(enum_case, capability_mapping, for_op, for_cap): 250 """Returns the availability specification string for the given enum case. 251 252 Arguments: 253 - enum_case: the enum case to generate availability spec for. It may contain 254 'version', 'lastVersion', 'extensions', or 'capabilities'. 255 - capability_mapping: mapping from duplicated capability symbols to the 256 canonicalized symbol chosen for SPIRVBase.td. 257 - for_op: bool value indicating whether this is the availability spec for an 258 op itself. 259 - for_cap: bool value indicating whether this is the availability spec for 260 capabilities themselves. 261 262 Returns: 263 - A `let availability = [...];` string if with availability spec or 264 empty string if without availability spec 265 """ 266 assert not (for_op and for_cap), 'cannot set both for_op and for_cap' 267 268 DEFAULT_MIN_VERSION = 'MinVersion<SPV_V_1_0>' 269 DEFAULT_MAX_VERSION = 'MaxVersion<SPV_V_1_5>' 270 DEFAULT_CAP = 'Capability<[]>' 271 DEFAULT_EXT = 'Extension<[]>' 272 273 min_version = enum_case.get('version', '') 274 if min_version == 'None': 275 min_version = '' 276 elif min_version: 277 min_version = 'MinVersion<SPV_V_{}>'.format(min_version.replace('.', '_')) 278 # TODO: delete this once ODS can support dialect-specific content 279 # and we can use omission to mean no requirements. 280 if for_op and not min_version: 281 min_version = DEFAULT_MIN_VERSION 282 283 max_version = enum_case.get('lastVersion', '') 284 if max_version: 285 max_version = 'MaxVersion<SPV_V_{}>'.format(max_version.replace('.', '_')) 286 # TODO: delete this once ODS can support dialect-specific content 287 # and we can use omission to mean no requirements. 288 if for_op and not max_version: 289 max_version = DEFAULT_MAX_VERSION 290 291 exts = enum_case.get('extensions', []) 292 if exts: 293 exts = 'Extension<[{}]>'.format(', '.join(sorted(set(exts)))) 294 # We need to strip the minimal version requirement if this symbol is 295 # available via an extension, which means *any* SPIR-V version can support 296 # it as long as the extension is provided. The grammar's 'version' field 297 # under such case should be interpreted as this symbol is introduced as 298 # a core symbol since the given version, rather than a minimal version 299 # requirement. 300 min_version = DEFAULT_MIN_VERSION if for_op else '' 301 # TODO: delete this once ODS can support dialect-specific content 302 # and we can use omission to mean no requirements. 303 if for_op and not exts: 304 exts = DEFAULT_EXT 305 306 caps = enum_case.get('capabilities', []) 307 implies = '' 308 if caps: 309 canonicalized_caps = [] 310 for c in caps: 311 if c in capability_mapping: 312 canonicalized_caps.append(capability_mapping[c]) 313 else: 314 canonicalized_caps.append(c) 315 prefixed_caps = [ 316 'SPV_C_{}'.format(c) for c in sorted(set(canonicalized_caps)) 317 ] 318 if for_cap: 319 # If this is generating the availability for capabilities, we need to 320 # put the capability "requirements" in implies field because now 321 # the "capabilities" field in the source grammar means so. 322 caps = '' 323 implies = 'list<I32EnumAttrCase> implies = [{}];'.format( 324 ', '.join(prefixed_caps)) 325 else: 326 caps = 'Capability<[{}]>'.format(', '.join(prefixed_caps)) 327 implies = '' 328 # TODO: delete this once ODS can support dialect-specific content 329 # and we can use omission to mean no requirements. 330 if for_op and not caps: 331 caps = DEFAULT_CAP 332 333 avail = '' 334 # Compose availability spec if any of the requirements is not empty. 335 # For ops, because we have a default in SPV_Op class, omit if the spec 336 # is the same. 337 if (min_version or max_version or caps or exts) and not ( 338 for_op and min_version == DEFAULT_MIN_VERSION and 339 max_version == DEFAULT_MAX_VERSION and caps == DEFAULT_CAP and 340 exts == DEFAULT_EXT): 341 joined_spec = ',\n '.join( 342 [e for e in [min_version, max_version, exts, caps] if e]) 343 avail = '{} availability = [\n {}\n ];'.format( 344 'let' if for_op else 'list<Availability>', joined_spec) 345 346 return '{}{}{}'.format(implies, '\n ' if implies and avail else '', avail) 347 348 349def gen_operand_kind_enum_attr(operand_kind, capability_mapping): 350 """Generates the TableGen EnumAttr definition for the given operand kind. 351 352 Returns: 353 - The operand kind's name 354 - A string containing the TableGen EnumAttr definition 355 """ 356 if 'enumerants' not in operand_kind: 357 return '', '' 358 359 # Returns a symbol for the given case in the given kind. This function 360 # handles Dim specially to avoid having numbers as the start of symbols, 361 # which does not play well with C++ and the MLIR parser. 362 def get_case_symbol(kind_name, case_name): 363 if kind_name == 'Dim': 364 if case_name == '1D' or case_name == '2D' or case_name == '3D': 365 return 'Dim{}'.format(case_name) 366 return case_name 367 368 kind_name = operand_kind['kind'] 369 is_bit_enum = operand_kind['category'] == 'BitEnum' 370 kind_category = 'Bit' if is_bit_enum else 'I32' 371 kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z']) 372 373 name_to_case_dict = {} 374 for case in operand_kind['enumerants']: 375 name_to_case_dict[case['enumerant']] = case 376 377 if kind_name == 'Capability': 378 # Special treatment for capability cases: we need to sort them topologically 379 # because a capability can refer to another via the 'implies' field. 380 kind_cases = toposort_capabilities(operand_kind['enumerants'], 381 capability_mapping) 382 else: 383 kind_cases = [(case['enumerant'], case['value']) 384 for case in operand_kind['enumerants']] 385 kind_cases, _ = uniquify_enum_cases(kind_cases) 386 max_len = max([len(symbol) for (symbol, _) in kind_cases]) 387 388 # Generate the definition for each enum case 389 fmt_str = 'def SPV_{acronym}_{case} {colon:>{offset}} '\ 390 '{category}EnumAttrCase<"{symbol}", {value}>{avail}' 391 case_defs = [] 392 for case in kind_cases: 393 avail = get_availability_spec(name_to_case_dict[case[0]], 394 capability_mapping, 395 False, kind_name == 'Capability') 396 case_def = fmt_str.format( 397 category=kind_category, 398 acronym=kind_acronym, 399 case=case[0], 400 symbol=get_case_symbol(kind_name, case[0]), 401 value=case[1], 402 avail=' {{\n {}\n}}'.format(avail) if avail else ';', 403 colon=':', 404 offset=(max_len + 1 - len(case[0]))) 405 case_defs.append(case_def) 406 case_defs = '\n'.join(case_defs) 407 408 # Generate the list of enum case names 409 fmt_str = 'SPV_{acronym}_{symbol}'; 410 case_names = [fmt_str.format(acronym=kind_acronym,symbol=case[0]) 411 for case in kind_cases] 412 413 # Split them into sublists and concatenate into multiple lines 414 case_names = split_list_into_sublists(case_names) 415 case_names = ['{:6}'.format('') + ', '.join(sublist) 416 for sublist in case_names] 417 case_names = ',\n'.join(case_names) 418 419 # Generate the enum attribute definition 420 enum_attr = '''def SPV_{name}Attr : 421 SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", [ 422{cases} 423 ]>;'''.format( 424 name=kind_name, category=kind_category, cases=case_names) 425 return kind_name, case_defs + '\n\n' + enum_attr 426 427 428def gen_opcode(instructions): 429 """ Generates the TableGen definition to map opname to opcode 430 431 Returns: 432 - A string containing the TableGen SPV_OpCode definition 433 """ 434 435 max_len = max([len(inst['opname']) for inst in instructions]) 436 def_fmt_str = 'def SPV_OC_{name} {colon:>{offset}} '\ 437 'I32EnumAttrCase<"{name}", {value}>;' 438 opcode_defs = [ 439 def_fmt_str.format( 440 name=inst['opname'], 441 value=inst['opcode'], 442 colon=':', 443 offset=(max_len + 1 - len(inst['opname']))) for inst in instructions 444 ] 445 opcode_str = '\n'.join(opcode_defs) 446 447 decl_fmt_str = 'SPV_OC_{name}' 448 opcode_list = [ 449 decl_fmt_str.format(name=inst['opname']) for inst in instructions 450 ] 451 opcode_list = split_list_into_sublists(opcode_list) 452 opcode_list = [ 453 '{:6}'.format('') + ', '.join(sublist) for sublist in opcode_list 454 ] 455 opcode_list = ',\n'.join(opcode_list) 456 enum_attr = 'def SPV_OpcodeAttr :\n'\ 457 ' SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\ 458 '{lst}\n'\ 459 ' ]>;'.format(name='Opcode', lst=opcode_list) 460 return opcode_str + '\n\n' + enum_attr 461 462def map_cap_to_opnames(instructions): 463 """Maps capabilities to instructions enabled by those capabilities 464 465 Arguments: 466 - instructions: a list containing a subset of SPIR-V instructions' grammar 467 Returns: 468 - A map with keys representing capabilities and values of lists of 469 instructions enabled by the corresponding key 470 """ 471 cap_to_inst = {} 472 473 for inst in instructions: 474 caps = inst['capabilities'] if 'capabilities' in inst else ['0_core_0'] 475 for cap in caps: 476 if cap not in cap_to_inst: 477 cap_to_inst[cap] = [] 478 cap_to_inst[cap].append(inst['opname']) 479 480 return cap_to_inst 481 482def gen_instr_coverage_report(path, instructions): 483 """Dumps to standard output a YAML report of current instruction coverage 484 485 Arguments: 486 - path: the path to SPIRBase.td 487 - instructions: a list containing all SPIR-V instructions' grammar 488 """ 489 with open(path, 'r') as f: 490 content = f.read() 491 492 content = content.split(AUTOGEN_OPCODE_SECTION_MARKER) 493 494 existing_opcodes = [k[11:] for k in re.findall('def SPV_OC_\w+', content[1])] 495 existing_instructions = list( 496 filter(lambda inst: (inst['opname'] in existing_opcodes), 497 instructions)) 498 499 instructions_opnames = [inst['opname'] for inst in instructions] 500 501 remaining_opcodes = list(set(instructions_opnames) - set(existing_opcodes)) 502 remaining_instructions = list( 503 filter(lambda inst: (inst['opname'] in remaining_opcodes), 504 instructions)) 505 506 rem_cap_to_instr = map_cap_to_opnames(remaining_instructions) 507 ex_cap_to_instr = map_cap_to_opnames(existing_instructions) 508 509 rem_cap_to_cov = {} 510 511 # Calculate coverage for each capability 512 for cap in rem_cap_to_instr: 513 if cap not in ex_cap_to_instr: 514 rem_cap_to_cov[cap] = 0.0 515 else: 516 rem_cap_to_cov[cap] = \ 517 (len(ex_cap_to_instr[cap]) / (len(ex_cap_to_instr[cap]) \ 518 + len(rem_cap_to_instr[cap]))) 519 520 report = {} 521 522 # Merge the 3 maps into one report 523 for cap in rem_cap_to_instr: 524 report[cap] = {} 525 report[cap]['Supported Instructions'] = \ 526 ex_cap_to_instr[cap] if cap in ex_cap_to_instr else [] 527 report[cap]['Unsupported Instructions'] = rem_cap_to_instr[cap] 528 report[cap]['Coverage'] = '{}%'.format(int(rem_cap_to_cov[cap] * 100)) 529 530 print(yaml.dump(report)) 531 532def update_td_opcodes(path, instructions, filter_list): 533 """Updates SPIRBase.td with new generated opcode cases. 534 535 Arguments: 536 - path: the path to SPIRBase.td 537 - instructions: a list containing all SPIR-V instructions' grammar 538 - filter_list: a list containing new opnames to add 539 """ 540 541 with open(path, 'r') as f: 542 content = f.read() 543 544 content = content.split(AUTOGEN_OPCODE_SECTION_MARKER) 545 assert len(content) == 3 546 547 # Extend opcode list with existing list 548 existing_opcodes = [k[11:] for k in re.findall('def SPV_OC_\w+', content[1])] 549 filter_list.extend(existing_opcodes) 550 filter_list = list(set(filter_list)) 551 552 # Generate the opcode for all instructions in SPIR-V 553 filter_instrs = list( 554 filter(lambda inst: (inst['opname'] in filter_list), instructions)) 555 # Sort instruction based on opcode 556 filter_instrs.sort(key=lambda inst: inst['opcode']) 557 opcode = gen_opcode(filter_instrs) 558 559 # Substitute the opcode 560 content = content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + \ 561 opcode + '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER \ 562 + content[2] 563 564 with open(path, 'w') as f: 565 f.write(content) 566 567 568def update_td_enum_attrs(path, operand_kinds, filter_list): 569 """Updates SPIRBase.td with new generated enum definitions. 570 571 Arguments: 572 - path: the path to SPIRBase.td 573 - operand_kinds: a list containing all operand kinds' grammar 574 - filter_list: a list containing new enums to add 575 """ 576 with open(path, 'r') as f: 577 content = f.read() 578 579 content = content.split(AUTOGEN_ENUM_SECTION_MARKER) 580 assert len(content) == 3 581 582 # Extend filter list with existing enum definitions 583 existing_kinds = [ 584 k[8:-4] for k in re.findall('def SPV_\w+Attr', content[1])] 585 filter_list.extend(existing_kinds) 586 587 capability_mapping = get_capability_mapping(operand_kinds) 588 589 # Generate definitions for all enums in filter list 590 defs = [ 591 gen_operand_kind_enum_attr(kind, capability_mapping) 592 for kind in operand_kinds 593 if kind['kind'] in filter_list 594 ] 595 # Sort alphabetically according to enum name 596 defs.sort(key=lambda enum : enum[0]) 597 # Only keep the definitions from now on 598 # Put Capability's definition at the very beginning because capability cases 599 # will be referenced later 600 defs = [enum[1] for enum in defs if enum[0] == 'Capability' 601 ] + [enum[1] for enum in defs if enum[0] != 'Capability'] 602 603 # Substitute the old section 604 content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \ 605 '\n\n'.join(defs) + "\n\n// End " + AUTOGEN_ENUM_SECTION_MARKER \ 606 + content[2]; 607 608 with open(path, 'w') as f: 609 f.write(content) 610 611 612def snake_casify(name): 613 """Turns the given name to follow snake_case convention.""" 614 name = re.sub('\W+', '', name).split() 615 name = [s.lower() for s in name] 616 return '_'.join(name) 617 618 619def map_spec_operand_to_ods_argument(operand): 620 """Maps an operand in SPIR-V JSON spec to an op argument in ODS. 621 622 Arguments: 623 - A dict containing the operand's kind, quantifier, and name 624 625 Returns: 626 - A string containing both the type and name for the argument 627 """ 628 kind = operand['kind'] 629 quantifier = operand.get('quantifier', '') 630 631 # These instruction "operands" are for encoding the results; they should 632 # not be handled here. 633 assert kind != 'IdResultType', 'unexpected to handle "IdResultType" kind' 634 assert kind != 'IdResult', 'unexpected to handle "IdResult" kind' 635 636 if kind == 'IdRef': 637 if quantifier == '': 638 arg_type = 'SPV_Type' 639 elif quantifier == '?': 640 arg_type = 'Optional<SPV_Type>' 641 else: 642 arg_type = 'Variadic<SPV_Type>' 643 elif kind == 'IdMemorySemantics' or kind == 'IdScope': 644 # TODO: Need to further constrain 'IdMemorySemantics' 645 # and 'IdScope' given that they should be generated from OpConstant. 646 assert quantifier == '', ('unexpected to have optional/variadic memory ' 647 'semantics or scope <id>') 648 arg_type = 'SPV_' + kind[2:] + 'Attr' 649 elif kind == 'LiteralInteger': 650 if quantifier == '': 651 arg_type = 'I32Attr' 652 elif quantifier == '?': 653 arg_type = 'OptionalAttr<I32Attr>' 654 else: 655 arg_type = 'OptionalAttr<I32ArrayAttr>' 656 elif kind == 'LiteralString' or \ 657 kind == 'LiteralContextDependentNumber' or \ 658 kind == 'LiteralExtInstInteger' or \ 659 kind == 'LiteralSpecConstantOpInteger' or \ 660 kind == 'PairLiteralIntegerIdRef' or \ 661 kind == 'PairIdRefLiteralInteger' or \ 662 kind == 'PairIdRefIdRef': 663 assert False, '"{}" kind unimplemented'.format(kind) 664 else: 665 # The rest are all enum operands that we represent with op attributes. 666 assert quantifier != '*', 'unexpected to have variadic enum attribute' 667 arg_type = 'SPV_{}Attr'.format(kind) 668 if quantifier == '?': 669 arg_type = 'OptionalAttr<{}>'.format(arg_type) 670 671 name = operand.get('name', '') 672 name = snake_casify(name) if name else kind.lower() 673 674 return '{}:${}'.format(arg_type, name) 675 676 677def get_description(text, appendix): 678 """Generates the description for the given SPIR-V instruction. 679 680 Arguments: 681 - text: Textual description of the operation as string. 682 - appendix: Additional contents to attach in description as string, 683 includking IR examples, and others. 684 685 Returns: 686 - A string that corresponds to the description of the Tablegen op. 687 """ 688 fmt_str = '{text}\n\n <!-- End of AutoGen section -->\n{appendix}\n ' 689 return fmt_str.format(text=text, appendix=appendix) 690 691 692def get_op_definition(instruction, opname, doc, existing_info, capability_mapping, settings): 693 """Generates the TableGen op definition for the given SPIR-V instruction. 694 695 Arguments: 696 - instruction: the instruction's SPIR-V JSON grammar 697 - doc: the instruction's SPIR-V HTML doc 698 - existing_info: a dict containing potential manually specified sections for 699 this instruction 700 - capability_mapping: mapping from duplicated capability symbols to the 701 canonicalized symbol chosen for SPIRVBase.td 702 703 Returns: 704 - A string containing the TableGen op definition 705 """ 706 if settings.gen_cl_ops: 707 fmt_str = ('def SPV_{opname}Op : ' 708 'SPV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > ' 709 '{{\n let summary = {summary};\n\n let description = ' 710 '[{{\n{description}}}];{availability}\n') 711 else: 712 fmt_str = ('def SPV_{opname_src}Op : ' 713 'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> ' 714 '{{\n let summary = {summary};\n\n let description = ' 715 '[{{\n{description}}}];{availability}\n') 716 717 inst_category = existing_info.get('inst_category', 'Op') 718 if inst_category == 'Op': 719 fmt_str +='\n let arguments = (ins{args});\n\n'\ 720 ' let results = (outs{results});\n' 721 722 fmt_str +='{extras}'\ 723 '}}\n' 724 725 opname_src = instruction['opname'] 726 if opname.startswith('Op'): 727 opname_src = opname_src[2:] 728 729 category_args = existing_info.get('category_args', '') 730 731 if '\n' in doc: 732 summary, text = doc.split('\n', 1) 733 else: 734 summary = doc 735 text = '' 736 wrapper = textwrap.TextWrapper( 737 width=76, initial_indent=' ', subsequent_indent=' ') 738 739 # Format summary. If the summary can fit in the same line, we print it out 740 # as a "-quoted string; otherwise, wrap the lines using "[{...}]". 741 summary = summary.strip(); 742 if len(summary) + len(' let summary = "";') <= 80: 743 summary = '"{}"'.format(summary) 744 else: 745 summary = '[{{\n{}\n }}]'.format(wrapper.fill(summary)) 746 747 # Wrap text 748 text = text.split('\n') 749 text = [wrapper.fill(line) for line in text if line] 750 text = '\n\n'.join(text) 751 752 operands = instruction.get('operands', []) 753 754 # Op availability 755 avail = get_availability_spec(instruction, capability_mapping, True, False) 756 if avail: 757 avail = '\n\n {0}'.format(avail) 758 759 # Set op's result 760 results = '' 761 if len(operands) > 0 and operands[0]['kind'] == 'IdResultType': 762 results = '\n SPV_Type:$result\n ' 763 operands = operands[1:] 764 if 'results' in existing_info: 765 results = existing_info['results'] 766 767 # Ignore the operand standing for the result <id> 768 if len(operands) > 0 and operands[0]['kind'] == 'IdResult': 769 operands = operands[1:] 770 771 # Set op' argument 772 arguments = existing_info.get('arguments', None) 773 if arguments is None: 774 arguments = [map_spec_operand_to_ods_argument(o) for o in operands] 775 arguments = ',\n '.join(arguments) 776 if arguments: 777 # Prepend and append whitespace for formatting 778 arguments = '\n {}\n '.format(arguments) 779 780 description = existing_info.get('description', None) 781 if description is None: 782 assembly = '\n ```\n'\ 783 ' [TODO]\n'\ 784 ' ```mlir\n\n'\ 785 ' #### Example:\n\n'\ 786 ' ```\n'\ 787 ' [TODO]\n' \ 788 ' ```' 789 description = get_description(text, assembly) 790 791 return fmt_str.format( 792 opname=opname, 793 opname_src=opname_src, 794 opcode=instruction['opcode'], 795 category_args=category_args, 796 inst_category=inst_category, 797 traits=existing_info.get('traits', ''), 798 summary=summary, 799 description=description, 800 availability=avail, 801 args=arguments, 802 results=results, 803 extras=existing_info.get('extras', '')) 804 805 806def get_string_between(base, start, end): 807 """Extracts a substring with a specified start and end from a string. 808 809 Arguments: 810 - base: string to extract from. 811 - start: string to use as the start of the substring. 812 - end: string to use as the end of the substring. 813 814 Returns: 815 - The substring if found 816 - The part of the base after end of the substring. Is the base string itself 817 if the substring wasnt found. 818 """ 819 split = base.split(start, 1) 820 if len(split) == 2: 821 rest = split[1].split(end, 1) 822 assert len(rest) == 2, \ 823 'cannot find end "{end}" while extracting substring '\ 824 'starting with {start}'.format(start=start, end=end) 825 return rest[0].rstrip(end), rest[1] 826 return '', split[0] 827 828 829def get_string_between_nested(base, start, end): 830 """Extracts a substring with a nested start and end from a string. 831 832 Arguments: 833 - base: string to extract from. 834 - start: string to use as the start of the substring. 835 - end: string to use as the end of the substring. 836 837 Returns: 838 - The substring if found 839 - The part of the base after end of the substring. Is the base string itself 840 if the substring wasn't found. 841 """ 842 split = base.split(start, 1) 843 if len(split) == 2: 844 # Handle nesting delimiters 845 rest = split[1] 846 unmatched_start = 1 847 index = 0 848 while unmatched_start > 0 and index < len(rest): 849 if rest[index:].startswith(end): 850 unmatched_start -= 1 851 if unmatched_start == 0: 852 break 853 index += len(end) 854 elif rest[index:].startswith(start): 855 unmatched_start += 1 856 index += len(start) 857 else: 858 index += 1 859 860 assert index < len(rest), \ 861 'cannot find end "{end}" while extracting substring '\ 862 'starting with "{start}"'.format(start=start, end=end) 863 return rest[:index], rest[index + len(end):] 864 return '', split[0] 865 866 867def extract_td_op_info(op_def): 868 """Extracts potentially manually specified sections in op's definition. 869 870 Arguments: - A string containing the op's TableGen definition 871 872 Returns: 873 - A dict containing potential manually specified sections 874 """ 875 # Get opname 876 opname = [o[8:-2] for o in re.findall('def SPV_\w+Op', op_def)] 877 assert len(opname) == 1, 'more than one ops in the same section!' 878 opname = opname[0] 879 880 # Get instruction category 881 inst_category = [ 882 o[4:] for o in re.findall('SPV_\w+Op', 883 op_def.split(':', 1)[1]) 884 ] 885 assert len(inst_category) <= 1, 'more than one ops in the same section!' 886 inst_category = inst_category[0] if len(inst_category) == 1 else 'Op' 887 888 # Get category_args 889 op_tmpl_params, _ = get_string_between_nested(op_def, '<', '>') 890 opstringname, rest = get_string_between(op_tmpl_params, '"', '"') 891 category_args = rest.split('[', 1)[0] 892 893 # Get traits 894 traits, _ = get_string_between_nested(rest, '[', ']') 895 896 # Get description 897 description, rest = get_string_between(op_def, 'let description = [{\n', 898 '}];\n') 899 900 # Get arguments 901 args, rest = get_string_between(rest, ' let arguments = (ins', ');\n') 902 903 # Get results 904 results, rest = get_string_between(rest, ' let results = (outs', ');\n') 905 906 extras = rest.strip(' }\n') 907 if extras: 908 extras = '\n {}\n'.format(extras) 909 910 return { 911 # Prefix with 'Op' to make it consistent with SPIR-V spec 912 'opname': 'Op{}'.format(opname), 913 'inst_category': inst_category, 914 'category_args': category_args, 915 'traits': traits, 916 'description': description, 917 'arguments': args, 918 'results': results, 919 'extras': extras 920 } 921 922 923def update_td_op_definitions(path, instructions, docs, filter_list, 924 inst_category, capability_mapping, settings): 925 """Updates SPIRVOps.td with newly generated op definition. 926 927 Arguments: 928 - path: path to SPIRVOps.td 929 - instructions: SPIR-V JSON grammar for all instructions 930 - docs: SPIR-V HTML doc for all instructions 931 - filter_list: a list containing new opnames to include 932 - capability_mapping: mapping from duplicated capability symbols to the 933 canonicalized symbol chosen for SPIRVBase.td. 934 935 Returns: 936 - A string containing all the TableGen op definitions 937 """ 938 with open(path, 'r') as f: 939 content = f.read() 940 941 # Split the file into chunks, each containing one op. 942 ops = content.split(AUTOGEN_OP_DEF_SEPARATOR) 943 header = ops[0] 944 footer = ops[-1] 945 ops = ops[1:-1] 946 947 # For each existing op, extract the manually-written sections out to retain 948 # them when re-generating the ops. Also append the existing ops to filter 949 # list. 950 name_op_map = {} # Map from opname to its existing ODS definition 951 op_info_dict = {} 952 for op in ops: 953 info_dict = extract_td_op_info(op) 954 opname = info_dict['opname'] 955 name_op_map[opname] = op 956 op_info_dict[opname] = info_dict 957 filter_list.append(opname) 958 filter_list = sorted(list(set(filter_list))) 959 960 op_defs = [] 961 962 if settings.gen_cl_ops: 963 fix_opname = lambda src: src.replace('CL','').lower() 964 else: 965 fix_opname = lambda src: src 966 967 for opname in filter_list: 968 # Find the grammar spec for this op 969 try: 970 fixed_opname = fix_opname(opname) 971 instruction = next( 972 inst for inst in instructions if inst['opname'] == fixed_opname) 973 974 op_defs.append( 975 get_op_definition( 976 instruction, opname, docs[fixed_opname], 977 op_info_dict.get(opname, {'inst_category': inst_category}), 978 capability_mapping, settings)) 979 except StopIteration: 980 # This is an op added by us; use the existing ODS definition. 981 op_defs.append(name_op_map[opname]) 982 983 # Substitute the old op definitions 984 op_defs = [header] + op_defs + [footer] 985 content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs) 986 987 with open(path, 'w') as f: 988 f.write(content) 989 990 991if __name__ == '__main__': 992 import argparse 993 994 cli_parser = argparse.ArgumentParser( 995 description='Update SPIR-V dialect definitions using SPIR-V spec') 996 997 cli_parser.add_argument( 998 '--base-td-path', 999 dest='base_td_path', 1000 type=str, 1001 default=None, 1002 help='Path to SPIRVBase.td') 1003 cli_parser.add_argument( 1004 '--op-td-path', 1005 dest='op_td_path', 1006 type=str, 1007 default=None, 1008 help='Path to SPIRVOps.td') 1009 1010 cli_parser.add_argument( 1011 '--new-enum', 1012 dest='new_enum', 1013 type=str, 1014 default=None, 1015 help='SPIR-V enum to be added to SPIRVBase.td') 1016 cli_parser.add_argument( 1017 '--new-opcodes', 1018 dest='new_opcodes', 1019 type=str, 1020 default=None, 1021 nargs='*', 1022 help='update SPIR-V opcodes in SPIRVBase.td') 1023 cli_parser.add_argument( 1024 '--new-inst', 1025 dest='new_inst', 1026 type=str, 1027 default=None, 1028 nargs='*', 1029 help='SPIR-V instruction to be added to ops file') 1030 cli_parser.add_argument( 1031 '--inst-category', 1032 dest='inst_category', 1033 type=str, 1034 default='Op', 1035 help='SPIR-V instruction category used for choosing '\ 1036 'the TableGen base class to define this op') 1037 cli_parser.add_argument( 1038 '--gen-cl-ops', 1039 dest='gen_cl_ops', 1040 help='Generate OpenCL Extended Instruction Set op', 1041 action='store_true') 1042 cli_parser.set_defaults(gen_cl_ops=False) 1043 cli_parser.add_argument('--gen-inst-coverage', dest='gen_inst_coverage', action='store_true') 1044 cli_parser.set_defaults(gen_inst_coverage=False) 1045 1046 args = cli_parser.parse_args() 1047 1048 if args.gen_cl_ops: 1049 ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL 1050 ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL 1051 else: 1052 ext_html_url = None 1053 ext_json_url = None 1054 1055 operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url) 1056 1057 # Define new enum attr 1058 if args.new_enum is not None: 1059 assert args.base_td_path is not None 1060 filter_list = [args.new_enum] if args.new_enum else [] 1061 update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list) 1062 1063 # Define new opcode 1064 if args.new_opcodes is not None: 1065 assert args.base_td_path is not None 1066 update_td_opcodes(args.base_td_path, instructions, args.new_opcodes) 1067 1068 # Define new op 1069 if args.new_inst is not None: 1070 assert args.op_td_path is not None 1071 docs = get_spirv_doc_from_html_spec(ext_html_url, args) 1072 capability_mapping = get_capability_mapping(operand_kinds) 1073 update_td_op_definitions(args.op_td_path, instructions, docs, args.new_inst, 1074 args.inst_category, capability_mapping, args) 1075 print('Done. Note that this script just generates a template; ', end='') 1076 print('please read the spec and update traits, arguments, and ', end='') 1077 print('results accordingly.') 1078 1079 if args.gen_inst_coverage: 1080 gen_instr_coverage_report(args.base_td_path, instructions) 1081