1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5# See https://llvm.org/LICENSE.txt for license information.
6# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7
8# Script for updating SPIR-V dialect by scraping information from SPIR-V
9# HTML and JSON specs from the Internet.
10#
11# For example, to define the enum attribute for SPIR-V memory model:
12#
13# ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \
14#                        --new-enum MemoryModel
15#
# The 'operand_kinds' list of spirv.core.grammar.json contains all supported
# SPIR-V enum classes.
18
19import itertools
20import re
21import requests
22import textwrap
23import yaml
24
# Core SPIR-V specification: human-readable HTML docs and JSON grammar.
SPIRV_HTML_SPEC_URL = (
    'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html')
SPIRV_JSON_SPEC_URL = (
    'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/'
    'include/spirv/unified1/spirv.core.grammar.json')

# OpenCL extended instruction set: HTML docs and JSON grammar.
SPIRV_CL_EXT_HTML_SPEC_URL = (
    'https://www.khronos.org/registry/SPIR-V/specs/unified1/'
    'OpenCL.ExtendedInstructionSet.100.html')
SPIRV_CL_EXT_JSON_SPEC_URL = (
    'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/'
    'include/spirv/unified1/extinst.opencl.std.100.grammar.json')

# Separator emitted between consecutive generated op definitions.
AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n'
# Markers delimiting the auto-generated sections of SPIRVBase.td; the text
# between the two occurrences of a marker is rewritten by this script.
AUTOGEN_ENUM_SECTION_MARKER = (
    'enum section. Generated from SPIR-V spec; DO NOT MODIFY!')
AUTOGEN_OPCODE_SECTION_MARKER = (
    'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!')
35
def get_spirv_doc_from_html_spec(url, settings):
  """Extracts instruction documentation from a SPIR-V HTML spec.

  Arguments:
    - url: URL of the HTML spec to scrape; None selects SPIRV_HTML_SPEC_URL.
    - settings: parsed command-line settings; `settings.gen_cl_ops` selects
      the page layout of the OpenCL extended instruction set spec instead of
      the core spec layout.

  Returns:
    - A dict mapping from instruction opname to documentation (the text of
      the instruction's table cell with its first line, the opname, removed).

  Raises:
    - requests.HTTPError if the spec cannot be downloaded.
  """
  if url is None:
    url = SPIRV_HTML_SPEC_URL

  response = requests.get(url)
  # Fail loudly on HTTP errors instead of parsing an error page and crashing
  # later on a missing section anchor.
  response.raise_for_status()

  from bs4 import BeautifulSoup
  spirv = BeautifulSoup(response.content, 'html.parser')

  # The two specs lay out their instruction tables differently: the OpenCL
  # page keeps them in 'sect2' divs next to <h2 id="_binary_form"> with the
  # text directly in the first <td>; the core spec keeps them in 'sect3' divs
  # next to <h3 id="_instructions_3"> with the text inside a <p>.
  if settings.gen_cl_ops:
    heading, heading_id, section_class = 'h2', '_binary_form', 'sect2'
  else:
    heading, heading_id, section_class = 'h3', '_instructions_3', 'sect3'

  section_anchor = spirv.find(heading, {'id': heading_id})
  assert section_anchor is not None, \
      'cannot find <{} id="{}"> in {}'.format(heading, heading_id, url)

  doc = {}
  sections = section_anchor.parent.find_all('div', {'class': section_class})
  for section in sections:
    for table in section.find_all('table'):
      inst_html = table.tbody.tr.td
      if not settings.gen_cl_ops:
        inst_html = inst_html.p
      opname = inst_html.a['id']
      # Ignore the first line, which is just the opname. partition() yields
      # an empty doc instead of raising IndexError if nothing follows it.
      doc[opname] = inst_html.text.partition('\n')[2].strip()

  return doc
71
72
def get_spirv_grammar_from_json_spec(url):
  """Extracts operand kind and instruction grammar from SPIR-V JSON specs.

  Operand kinds always come from the core grammar (SPIRV_JSON_SPEC_URL).
  Instructions come from the core grammar when `url` is None, and from the
  grammar at `url` (e.g. an extended instruction set) otherwise.

  Arguments:
    - url: URL of an extended instruction set JSON grammar, or None.

  Returns:
    - A list containing all operand kinds' grammar
    - A list containing all instructions' grammar

  Raises:
    - requests.HTTPError if a grammar cannot be downloaded.
  """
  import json

  def fetch_json(spec_url):
    # Downloads and decodes one JSON grammar, failing loudly on HTTP errors.
    response = requests.get(spec_url)
    response.raise_for_status()
    return json.loads(response.content)

  spirv = fetch_json(SPIRV_JSON_SPEC_URL)

  if url is None:
    return spirv['operand_kinds'], spirv['instructions']

  spirv_ext = fetch_json(url)
  return spirv['operand_kinds'], spirv_ext['instructions']
94
95
def split_list_into_sublists(items, max_len=80):
  """Splits the list of items into multiple sublists.

  This is to make sure the string composed from each sublist (items joined
  with ', ') stays within roughly `max_len` characters. Each item is charged
  its length plus 2 for the separator. An item that alone exceeds the budget
  gets a sublist of its own; empty sublists are never produced.

  Arguments:
    - items: a list of strings
    - max_len: character budget per sublist (defaults to 80)

  Returns:
    - A list of non-empty lists of strings, preserving the input order.
  """
  chunks = []
  chunk = []
  chunk_len = 0

  for item in items:
    item_len = len(item) + 2
    # Only start a new sublist when the current one already holds something;
    # otherwise an oversized first item would emit an empty sublist (which
    # turns into a stray leading ',' line in the generated TableGen list).
    if chunk and chunk_len + item_len > max_len:
      chunks.append(chunk)
      chunk = []
      chunk_len = 0
    chunk.append(item)
    chunk_len += item_len

  if chunk:
    chunks.append(chunk)

  return chunks
121
122
def uniquify_enum_cases(lst):
  """Prunes duplicate enum cases from the list.

  Cases sharing the same value are aliases of each other, and only one
  canonical case is kept per value. The canonical case is the alphabetically
  smallest symbol, which is typically the one without an extension suffix.
  The exception is 'HlslSemanticGOOGLE', which is never chosen; its single
  alias is kept instead.

  Arguments:
   - lst: List of (symbol, value) pairs, in any order. It is not modified.

  Returns:
   - A list with all duplicates removed, sorted according to value and
     holding the canonical case for each value.
   - A dict mapping each deduplicated symbol to its canonical symbol.
  """
  uniqued_cases = []
  duplicated_cases = {}

  # Sort by value into a new list, so the caller's list is left untouched
  # and groupby sees all cases with the same value consecutively.
  cases = sorted(lst, key=lambda x: x[1])

  for _, group in itertools.groupby(cases, key=lambda x: x[1]):
    # For each value, sort according to the enumerant symbol.
    sorted_group = sorted(group, key=lambda x: x[0])
    case = sorted_group[0]
    if case[0] == 'HlslSemanticGOOGLE':
      assert len(sorted_group) == 2, 'unexpected new variant for HlslSemantic'
      case = sorted_group[1]
    # Map every other symbol to the chosen one. The mapping is built only
    # after the choice is final, so no stale entry maps the kept symbol to a
    # dropped one.
    for symbol, _ in sorted_group:
      if symbol != case[0]:
        duplicated_cases[symbol] = case[0]
    uniqued_cases.append(case)

  return uniqued_cases, duplicated_cases
159
160
def toposort(dag, sort_fn):
  """Topologically sorts the given dag.

  Nodes are emitted in batches. Each batch holds every node whose incoming
  nodes have all been emitted already, ordered by `sort_fn`.

  Arguments:
    - dag: a dict mapping from a node to its incoming nodes (a set).
    - sort_fn: a key function for sorting nodes in the same batch.

  Returns:
    A list containing topologically sorted nodes.

  Raises:
    AssertionError if some nodes can never be emitted (e.g. a cycle).
  """
  pending = dag
  ordered = []
  while True:
    ready = {node for node, incoming in pending.items() if not incoming}
    if not ready:
      break
    ordered.extend(sorted(ready, key=sort_fn))
    # Remove the emitted nodes, along with the edges leaving them.
    pending = {
        node: incoming - ready
        for node, incoming in pending.items()
        if node not in ready
    }
  assert not pending, 'found cyclic dependency'
  return ordered
191
192
def toposort_capabilities(all_cases, capability_mapping):
  """Orders capabilities so that each comes after the ones it implies.

  Arguments:
    - all_cases: capability cases from the grammar; each is a dict with
      'enumerant', 'value' and optionally 'capabilities' (the implied
      capabilities).
    - capability_mapping: mapping from duplicated capability symbols to the
      canonicalized symbol chosen for SPIRVBase.td.

  Returns:
    A list of (symbol, value) pairs in topological order. Duplicated symbols
    are left out, and ties within a batch are ordered by value.
  """
  values = {case['enumerant']: case['value'] for case in all_cases}

  implied_by = {}
  for case in all_cases:
    symbol = case['enumerant']
    if symbol in capability_mapping:
      # Duplicated symbol; only its canonical spelling is emitted.
      continue
    implied_by[symbol] = {
        capability_mapping.get(c, c) for c in case.get('capabilities', [])
    }

  ordered = toposort(implied_by, lambda symbol: values[symbol])
  return [(symbol, values[symbol]) for symbol in ordered]
223
224
def get_capability_mapping(operand_kinds):
  """Returns the capability mapping from duplicated cases to canonicalized ones.

  Arguments:
    - operand_kinds: all operand kinds' grammar spec

  Returns:
    - A dict mapping from duplicated capability symbols to the canonicalized
      symbol chosen for SPIRVBase.td.
  """
  # If several operand kinds are named 'Capability', the last one is used.
  capability_kinds = [k for k in operand_kinds if k['kind'] == 'Capability']
  cap_kind = capability_kinds[-1] if capability_kinds else {}

  pairs = [(e['enumerant'], e['value']) for e in cap_kind['enumerants']]
  return uniquify_enum_cases(pairs)[1]
247
248
def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
  """Returns the availability specification string for the given enum case.

  Arguments:
    - enum_case: the enum case (or instruction) grammar to generate the
      availability spec for. It may contain 'version', 'lastVersion',
      'extensions', or 'capabilities'.
    - capability_mapping: mapping from duplicated capability symbols to the
      canonicalized symbol chosen for SPIRVBase.td.
    - for_op: whether this is the availability spec for an op itself. Ops
      get an explicit default for every missing requirement.
    - for_cap: whether this is the availability spec for a capability. The
      grammar's 'capabilities' field is then emitted as an `implies` list.

  Returns:
    - A string holding an optional `implies` list and/or an availability list
      (`let availability = [...];` for ops, `list<Availability> availability
      = [...];` otherwise), or an empty string if neither is needed.
  """
  assert not (for_op and for_cap), 'cannot set both for_op and for_cap'

  DEFAULT_MIN_VERSION = 'MinVersion<SPV_V_1_0>'
  DEFAULT_MAX_VERSION = 'MaxVersion<SPV_V_1_5>'
  DEFAULT_CAP = 'Capability<[]>'
  DEFAULT_EXT = 'Extension<[]>'

  def version_req(kind, version):
    # E.g. ('MinVersion', '1.3') -> 'MinVersion<SPV_V_1_3>'.
    return '{}<SPV_V_{}>'.format(kind, version.replace('.', '_'))

  # A 'version' of 'None' in the grammar means "no version requirement".
  version = enum_case.get('version', '')
  min_version = ''
  if version and version != 'None':
    min_version = version_req('MinVersion', version)

  last_version = enum_case.get('lastVersion', '')
  max_version = version_req('MaxVersion', last_version) if last_version else ''

  exts = ''
  extensions = enum_case.get('extensions', [])
  if extensions:
    exts = 'Extension<[{}]>'.format(', '.join(sorted(set(extensions))))
    # A symbol available through an extension is usable from *any* SPIR-V
    # version as long as the extension is provided; the grammar's 'version'
    # then only says when it became core, not a minimal version requirement.
    # (Ops still fall back to the default minimal version below.)
    min_version = ''

  caps = ''
  implies = ''
  required_caps = enum_case.get('capabilities', [])
  if required_caps:
    canonical = sorted({capability_mapping.get(c, c) for c in required_caps})
    cap_list = ', '.join('SPV_C_{}'.format(c) for c in canonical)
    if for_cap:
      # For capabilities, the grammar's "capabilities" field lists implied
      # capabilities rather than requirements.
      implies = 'list<I32EnumAttrCase> implies = [{}];'.format(cap_list)
    else:
      caps = 'Capability<[{}]>'.format(cap_list)

  # TODO: drop these defaults once ODS can support dialect-specific content
  # and we can use omission to mean no requirements.
  if for_op:
    min_version = min_version or DEFAULT_MIN_VERSION
    max_version = max_version or DEFAULT_MAX_VERSION
    exts = exts or DEFAULT_EXT
    caps = caps or DEFAULT_CAP

  requirements = [r for r in (min_version, max_version, exts, caps) if r]
  # The SPV_Op class already has the all-default spec, so an op whose spec
  # equals it needs no availability field.
  same_as_default = for_op and (min_version, max_version, exts, caps) == (
      DEFAULT_MIN_VERSION, DEFAULT_MAX_VERSION, DEFAULT_EXT, DEFAULT_CAP)

  avail = ''
  if requirements and not same_as_default:
    avail = '{} availability = [\n    {}\n  ];'.format(
        'let' if for_op else 'list<Availability>',
        ',\n    '.join(requirements))

  separator = '\n  ' if implies and avail else ''
  return implies + separator + avail
347
348
def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
  """Generates the TableGen EnumAttr definition for the given operand kind.

  Arguments:
    - operand_kind: the operand kind's grammar from the JSON spec.
    - capability_mapping: mapping from duplicated capability symbols to the
      canonicalized symbol chosen for SPIRVBase.td.

  Returns:
    - The operand kind's name
    - A string containing the TableGen enum case definitions followed by the
      EnumAttr definition
    Both are empty strings if the operand kind has no 'enumerants'.
  """
  if 'enumerants' not in operand_kind:
    return '', ''

  # Returns a symbol for the given case in the given kind. This function
  # handles Dim specially to avoid having numbers as the start of symbols,
  # which does not play well with C++ and the MLIR parser.
  def get_case_symbol(kind_name, case_name):
    if kind_name == 'Dim' and case_name in ('1D', '2D', '3D'):
      return 'Dim{}'.format(case_name)
    return case_name

  kind_name = operand_kind['kind']
  is_bit_enum = operand_kind['category'] == 'BitEnum'
  kind_category = 'Bit' if is_bit_enum else 'I32'
  # Upper-case letters of the kind name, e.g. 'MemoryModel' -> 'MM'.
  kind_acronym = ''.join(c for c in kind_name if 'A' <= c <= 'Z')

  name_to_case_dict = {
      case['enumerant']: case for case in operand_kind['enumerants']
  }

  if kind_name == 'Capability':
    # Special treatment for capability cases: we need to sort them topologically
    # because a capability can refer to another via the 'implies' field.
    kind_cases = toposort_capabilities(operand_kind['enumerants'],
                                       capability_mapping)
  else:
    kind_cases = [(case['enumerant'], case['value'])
                  for case in operand_kind['enumerants']]
    kind_cases, _ = uniquify_enum_cases(kind_cases)
  # default=0 keeps an empty 'enumerants' list from crashing max().
  max_len = max((len(symbol) for symbol, _ in kind_cases), default=0)

  # Generate the definition for each enum case, with the colons aligned.
  fmt_str = 'def SPV_{acronym}_{case} {colon:>{offset}} '\
            '{category}EnumAttrCase<"{symbol}", {value}>{avail}'
  case_defs = []
  for case in kind_cases:
    avail = get_availability_spec(name_to_case_dict[case[0]],
                                  capability_mapping,
                                  False, kind_name == 'Capability')
    case_def = fmt_str.format(
        category=kind_category,
        acronym=kind_acronym,
        case=case[0],
        symbol=get_case_symbol(kind_name, case[0]),
        value=case[1],
        avail=' {{\n  {}\n}}'.format(avail) if avail else ';',
        colon=':',
        offset=(max_len + 1 - len(case[0])))
    case_defs.append(case_def)
  case_defs = '\n'.join(case_defs)

  # Generate the list of enum case names
  fmt_str = 'SPV_{acronym}_{symbol}'
  case_names = [fmt_str.format(acronym=kind_acronym, symbol=case[0])
                for case in kind_cases]

  # Split them into sublists and concatenate into multiple lines
  case_names = split_list_into_sublists(case_names)
  case_names = ['{:6}'.format('') + ', '.join(sublist)
                for sublist in case_names]
  case_names = ',\n'.join(case_names)

  # Generate the enum attribute definition
  enum_attr = '''def SPV_{name}Attr :
    SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", [
{cases}
    ]>;'''.format(
          name=kind_name, category=kind_category, cases=case_names)
  return kind_name, case_defs + '\n\n' + enum_attr
426
427
def gen_opcode(instructions):
  """Generates the TableGen definitions that map opnames to opcodes.

  Arguments:
    - instructions: a list of instruction grammars, each with 'opname' and
      'opcode', in the order the cases should be emitted.

  Returns:
    - A string with one `def SPV_OC_<opname>` I32EnumAttrCase per instruction
      (colons aligned), a blank line, then the SPV_OpcodeAttr definition that
      lists every case.
  """
  opnames = [inst['opname'] for inst in instructions]
  widest = max(len(name) for name in opnames)

  case_fmt = ('def SPV_OC_{name} {colon:>{offset}} '
              'I32EnumAttrCase<"{name}", {value}>;')
  case_defs = '\n'.join(
      case_fmt.format(name=name, value=inst['opcode'], colon=':',
                      offset=widest + 1 - len(name))
      for name, inst in zip(opnames, instructions))

  # Lay the case names out six spaces in, several per line.
  rows = split_list_into_sublists(
      ['SPV_OC_{}'.format(name) for name in opnames])
  case_list = ',\n'.join(' ' * 6 + ', '.join(row) for row in rows)

  enum_attr = ('def SPV_OpcodeAttr :\n'
               '    SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [\n'
               '{}\n'
               '    ]>;').format(case_list)
  return '{}\n\n{}'.format(case_defs, enum_attr)
461
def map_cap_to_opnames(instructions):
  """Groups instruction opnames by the capabilities that enable them.

  Arguments:
    - instructions: a list containing a subset of SPIR-V instructions' grammar
  Returns:
    - A dict mapping each capability to the list of opnames it enables, in
      input order. Instructions without a 'capabilities' entry are grouped
      under the placeholder key '0_core_0'.
  """
  cap_to_inst = {}
  for inst in instructions:
    for cap in inst.get('capabilities', ['0_core_0']):
      cap_to_inst.setdefault(cap, []).append(inst['opname'])
  return cap_to_inst
481
def gen_instr_coverage_report(path, instructions):
  """Dumps to standard output a YAML report of current instruction coverage.

  The report is grouped by capability (instructions without capabilities are
  grouped under '0_core_0'). Only capabilities that still have unsupported
  instructions are included.

  Arguments:
    - path: the path to SPIRVBase.td
    - instructions: a list containing all SPIR-V instructions' grammar
  """
  with open(path, 'r') as f:
    content = f.read()

  content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
  # The opcode section must be delimited by exactly two markers.
  assert len(content) == 3, \
      'expected two opcode section markers in {}'.format(path)

  # Raw string: '\w' in a plain string literal is an invalid escape sequence.
  existing_opcodes = set(re.findall(r'def SPV_OC_(\w+)', content[1]))

  existing_instructions = [
      inst for inst in instructions if inst['opname'] in existing_opcodes]
  remaining_instructions = [
      inst for inst in instructions if inst['opname'] not in existing_opcodes]

  rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
  ex_cap_to_instr = map_cap_to_opnames(existing_instructions)

  # Merge supported/unsupported lists and the coverage ratio per capability.
  report = {}
  for cap, unsupported in rem_cap_to_instr.items():
    supported = ex_cap_to_instr.get(cap, [])
    coverage = len(supported) / (len(supported) + len(unsupported))
    report[cap] = {
        'Supported Instructions': supported,
        'Unsupported Instructions': unsupported,
        'Coverage': '{}%'.format(int(coverage * 100)),
    }

  print(yaml.dump(report))
531
def update_td_opcodes(path, instructions, filter_list):
  """Updates SPIRVBase.td with new generated opcode cases.

  The opcode section is regenerated so that it contains every opcode already
  defined there plus the ones in `filter_list`, sorted by opcode value.

  Arguments:
    - path: the path to SPIRVBase.td
    - instructions: a list containing all SPIR-V instructions' grammar
    - filter_list: a list containing new opnames to add; it is not modified
  """
  with open(path, 'r') as f:
    content = f.read()

  content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
  assert len(content) == 3

  # Keep every opcode already in the section, without mutating the caller's
  # list. Raw string: '\w' in a plain literal is an invalid escape sequence.
  wanted = set(filter_list)
  wanted.update(re.findall(r'def SPV_OC_(\w+)', content[1]))

  # Generate the opcodes for the wanted instructions, sorted by opcode.
  filter_instrs = sorted(
      (inst for inst in instructions if inst['opname'] in wanted),
      key=lambda inst: inst['opcode'])
  opcode = gen_opcode(filter_instrs)

  # Substitute the section between the two markers.
  content = (content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + opcode +
             '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER + content[2])

  with open(path, 'w') as f:
    f.write(content)
566
567
def update_td_enum_attrs(path, operand_kinds, filter_list):
  """Updates SPIRVBase.td with new generated enum definitions.

  The enum section is regenerated so that it contains every enum already
  defined there plus the ones in `filter_list`.

  Arguments:
    - path: the path to SPIRVBase.td
    - operand_kinds: a list containing all operand kinds' grammar
    - filter_list: a list containing new enums to add; it is not modified
  """
  with open(path, 'r') as f:
    content = f.read()

  content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
  assert len(content) == 3

  # Keep every enum already defined in the section ('def SPV_<Kind>Attr'),
  # without mutating the caller's list. Raw string: '\w' in a plain literal
  # is an invalid escape sequence.
  wanted = set(filter_list)
  wanted.update(re.findall(r'def SPV_(\w+)Attr', content[1]))

  capability_mapping = get_capability_mapping(operand_kinds)

  # Generate definitions for all wanted enums
  defs = [
      gen_operand_kind_enum_attr(kind, capability_mapping)
      for kind in operand_kinds
      if kind['kind'] in wanted
  ]
  # Sort alphabetically according to enum name
  defs.sort(key=lambda enum: enum[0])
  # Only keep the definitions from now on. Put Capability's definition at the
  # very beginning because capability cases will be referenced later.
  defs = ([enum[1] for enum in defs if enum[0] == 'Capability'] +
          [enum[1] for enum in defs if enum[0] != 'Capability'])

  # Substitute the old section
  content = (content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' +
             '\n\n'.join(defs) + '\n\n// End ' + AUTOGEN_ENUM_SECTION_MARKER +
             content[2])

  with open(path, 'w') as f:
    f.write(content)
610
611
def snake_casify(name):
  """Turns the given name to follow snake_case convention.

  Punctuation (e.g. the quotes around grammar operand names) is dropped, and
  the remaining whitespace-separated words are lowercased and joined with
  underscores: "'Result Type'" -> 'result_type'.
  """
  # Strip only characters that are neither word characters nor whitespace.
  # Stripping whitespace as well (as r'\W+' does) glues all words together
  # and turns the split below into a no-op.
  words = re.sub(r'[^\w\s]+', '', name).split()
  return '_'.join(word.lower() for word in words)
617
618
def map_spec_operand_to_ods_argument(operand):
  """Maps an operand in SPIR-V JSON spec to an op argument in ODS.

  Arguments:
    - operand: a dict containing the operand's kind and, optionally, its
      quantifier ('?' for optional, anything else non-empty for variadic)
      and name

  Returns:
    - A string containing both the type and name for the argument, e.g.
      'SPV_Type:$idref'

  Raises:
    - NotImplementedError for literal/pair operand kinds that have no ODS
      mapping yet.
  """
  kind = operand['kind']
  quantifier = operand.get('quantifier', '')

  # These instruction "operands" are for encoding the results; they should
  # not be handled here.
  assert kind != 'IdResultType', 'unexpected to handle "IdResultType" kind'
  assert kind != 'IdResult', 'unexpected to handle "IdResult" kind'

  if kind == 'IdRef':
    if quantifier == '':
      arg_type = 'SPV_Type'
    elif quantifier == '?':
      arg_type = 'Optional<SPV_Type>'
    else:
      arg_type = 'Variadic<SPV_Type>'
  elif kind in ('IdMemorySemantics', 'IdScope'):
    # TODO: Need to further constrain 'IdMemorySemantics'
    # and 'IdScope' given that they should be generated from OpConstant.
    assert quantifier == '', ('unexpected to have optional/variadic memory '
                              'semantics or scope <id>')
    arg_type = 'SPV_' + kind[2:] + 'Attr'
  elif kind == 'LiteralInteger':
    if quantifier == '':
      arg_type = 'I32Attr'
    elif quantifier == '?':
      arg_type = 'OptionalAttr<I32Attr>'
    else:
      arg_type = 'OptionalAttr<I32ArrayAttr>'
  elif kind in ('LiteralString', 'LiteralContextDependentNumber',
                'LiteralExtInstInteger', 'LiteralSpecConstantOpInteger',
                'PairLiteralIntegerIdRef', 'PairIdRefLiteralInteger',
                'PairIdRefIdRef'):
    # Raise explicitly: `assert False` is stripped under `python -O`, which
    # would fall through and fail later on the unbound `arg_type`.
    raise NotImplementedError('"{}" kind unimplemented'.format(kind))
  else:
    # The rest are all enum operands that we represent with op attributes.
    assert quantifier != '*', 'unexpected to have variadic enum attribute'
    arg_type = 'SPV_{}Attr'.format(kind)
    if quantifier == '?':
      arg_type = 'OptionalAttr<{}>'.format(arg_type)

  name = operand.get('name', '')
  name = snake_casify(name) if name else kind.lower()

  return '{}:${}'.format(arg_type, name)
675
676
def get_description(text, appendix):
  """Builds the description body of a generated SPIR-V op.

  Arguments:
    - text: Textual description of the operation as string.
    - appendix: Additional contents to attach in description as string,
                including IR examples, and others.

  Returns:
    - A string for the TableGen op's description. It holds the text, a blank
      line, a comment marking the end of the auto-generated part, then the
      appendix.
  """
  end_marker = '    <!-- End of AutoGen section -->'
  return '{}\n\n{}\n{}\n  '.format(text, end_marker, appendix)
690
691
def get_op_definition(instruction, opname, doc, existing_info, capability_mapping, settings):
  """Generates the TableGen op definition for the given SPIR-V instruction.

  Arguments:
    - instruction: the instruction's SPIR-V JSON grammar
    - opname: the name used for the generated op. If it starts with 'Op',
      the first two characters of the grammar's opname are dropped to form
      the op's source name.
    - doc: the instruction's SPIR-V HTML doc; its first line is used as the
      summary and the rest as the description text
    - existing_info: a dict containing potential manually specified sections for
      this instruction ('inst_category', 'category_args', 'traits',
      'description', 'arguments', 'results', 'extras')
    - capability_mapping: mapping from duplicated capability symbols to the
                   canonicalized symbol chosen for SPIRVBase.td
    - settings: parsed command-line settings; `settings.gen_cl_ops` selects
      the OpenCL extended instruction set template

  Returns:
    - A string containing the TableGen op definition
  """
  if settings.gen_cl_ops:
    fmt_str = ('def SPV_{opname}Op : '
               'SPV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > '
               '{{\n  let summary = {summary};\n\n  let description = '
               '[{{\n{description}}}];{availability}\n')
  else:
    fmt_str = ('def SPV_{opname_src}Op : '
               'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> '
               '{{\n  let summary = {summary};\n\n  let description = '
               '[{{\n{description}}}];{availability}\n')

  # Arguments and results sections are only emitted for the generic 'Op'
  # category.
  inst_category = existing_info.get('inst_category', 'Op')
  if inst_category == 'Op':
    fmt_str += '\n  let arguments = (ins{args});\n\n'\
               '  let results = (outs{results});\n'

  fmt_str += '{extras}'\
             '}}\n'

  opname_src = instruction['opname']
  if opname.startswith('Op'):
    opname_src = opname_src[2:]

  category_args = existing_info.get('category_args', '')

  # The first line of the doc is the summary; the rest is the description.
  if '\n' in doc:
    summary, text = doc.split('\n', 1)
  else:
    summary = doc
    text = ''
  wrapper = textwrap.TextWrapper(
      width=76, initial_indent='    ', subsequent_indent='    ')

  # Format summary. If the summary can fit in the same line, we print it out
  # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
  summary = summary.strip()
  if len(summary) + len('  let summary = "";') <= 80:
    summary = '"{}"'.format(summary)
  else:
    summary = '[{{\n{}\n  }}]'.format(wrapper.fill(summary))

  # Wrap each non-empty line of text as its own paragraph.
  text = text.split('\n')
  text = [wrapper.fill(line) for line in text if line]
  text = '\n\n'.join(text)

  operands = instruction.get('operands', [])

  # Op availability
  avail = get_availability_spec(instruction, capability_mapping, True, False)
  if avail:
    avail = '\n\n  {0}'.format(avail)

  # Set op's result
  results = ''
  if len(operands) > 0 and operands[0]['kind'] == 'IdResultType':
    results = '\n    SPV_Type:$result\n  '
    operands = operands[1:]
  if 'results' in existing_info:
    results = existing_info['results']

  # Ignore the operand standing for the result <id>
  if len(operands) > 0 and operands[0]['kind'] == 'IdResult':
    operands = operands[1:]

  # Set op's arguments
  arguments = existing_info.get('arguments', None)
  if arguments is None:
    arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
    arguments = ',\n    '.join(arguments)
    if arguments:
      # Prepend and append whitespace for formatting
      arguments = '\n    {}\n  '.format(arguments)

  description = existing_info.get('description', None)
  if description is None:
    # Placeholder syntax and example sections for a human to fill in. The
    # `mlir` info string belongs on the example's *opening* fence; on a
    # closing fence it would not close the syntax code block in Markdown.
    assembly = '\n    ```\n'\
               '    [TODO]\n'\
               '    ```\n\n'\
               '    #### Example:\n\n'\
               '    ```mlir\n'\
               '    [TODO]\n' \
               '    ```'
    description = get_description(text, assembly)

  return fmt_str.format(
      opname=opname,
      opname_src=opname_src,
      opcode=instruction['opcode'],
      category_args=category_args,
      inst_category=inst_category,
      traits=existing_info.get('traits', ''),
      summary=summary,
      description=description,
      availability=avail,
      args=arguments,
      results=results,
      extras=existing_info.get('extras', ''))
804
805
def get_string_between(base, start, end):
  """Extracts a substring with a specified start and end from a string.

  Arguments:
    - base: string to extract from.
    - start: string to use as the start of the substring.
    - end: string to use as the end of the substring.

  Returns:
    - The substring between the first `start` and the first `end` after it,
      or '' if `start` is not found.
    - The part of the base after end of the substring. Is the base string itself
      if the substring wasn't found.

  Raises:
    - AssertionError if `start` is found but no `end` follows it.
  """
  split = base.split(start, 1)
  if len(split) == 2:
    rest = split[1].split(end, 1)
    assert len(rest) == 2, \
           'cannot find end "{end}" while extracting substring '\
           'starting with {start}'.format(start=start, end=end)
    # split() has already removed `end`. rstrip(end) would additionally strip
    # any trailing characters that merely occur in `end` (e.g. ']' or ';').
    return rest[0], rest[1]
  return '', split[0]
827
828
def get_string_between_nested(base, start, end):
  """Extracts the substring enclosed by balanced `start`/`end` delimiters.

  The search begins at the first `start` in `base`. Further `start`/`end`
  pairs inside it are treated as nested, so the matching `end` is the one
  that balances the first `start`.

  Arguments:
    - base: string to extract from.
    - start: opening delimiter.
    - end: closing delimiter.

  Returns:
    - The enclosed substring if found, otherwise ''.
    - The part of base after the matching `end`, or base itself if `start`
      does not occur.

  Raises:
    - AssertionError if the opening delimiter is never balanced.
  """
  head, found, tail = base.partition(start)
  if not found:
    return '', head

  depth = 1
  pos = 0
  while pos < len(tail):
    if tail.startswith(end, pos):
      depth -= 1
      if depth == 0:
        break
      pos += len(end)
    elif tail.startswith(start, pos):
      depth += 1
      pos += len(start)
    else:
      pos += 1

  assert pos < len(tail), \
         'cannot find end "{end}" while extracting substring '\
         'starting with "{start}"'.format(start=start, end=end)
  return tail[:pos], tail[pos + len(end):]
865
866
def extract_td_op_info(op_def):
  """Extracts potentially manually specified sections in op's definition.

  Arguments:
    - op_def: a string containing one op's TableGen definition.

  Returns:
    - A dict with the keys 'opname', 'inst_category', 'category_args',
      'traits', 'description', 'arguments', 'results' and 'extras', holding
      the corresponding (possibly hand-edited) sections of the definition.
  """
  # Get opname: 'def SPV_<Name>Op' -> '<Name>' (drop 'def SPV_' and 'Op').
  opname = [o[8:-2] for o in re.findall(r'def SPV_\w+Op', op_def)]
  assert len(opname) == 1, \
      'expected exactly one op in the same section, found {}'.format(
          len(opname))
  opname = opname[0]

  # Get instruction category: the 'SPV_<Category>Op' base class referenced
  # after the first ':' ('SPV_' dropped); falls back to 'Op' when none matches.
  inst_category = [
      o[4:] for o in re.findall(r'SPV_\w+Op',
                                op_def.split(':', 1)[1])
  ]
  assert len(inst_category) <= 1, 'more than one ops in the same section!'
  inst_category = inst_category[0] if len(inst_category) == 1 else 'Op'

  # Get category_args: whatever follows the quoted op name inside the base
  # class template parameters, up to the traits list.
  op_tmpl_params, _ = get_string_between_nested(op_def, '<', '>')
  opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
  category_args = rest.split('[', 1)[0]

  # Get traits
  traits, _ = get_string_between_nested(rest, '[', ']')

  # Get description
  description, rest = get_string_between(op_def, 'let description = [{\n',
                                         '}];\n')

  # Get arguments
  args, rest = get_string_between(rest, '  let arguments = (ins', ');\n')

  # Get results
  results, rest = get_string_between(rest, '  let results = (outs', ');\n')

  # Anything after the results (minus the closing brace) is kept verbatim.
  extras = rest.strip(' }\n')
  if extras:
    extras = '\n  {}\n'.format(extras)

  return {
      # Prefix with 'Op' to make it consistent with SPIR-V spec
      'opname': 'Op{}'.format(opname),
      'inst_category': inst_category,
      'category_args': category_args,
      'traits': traits,
      'description': description,
      'arguments': args,
      'results': results,
      'extras': extras
  }
921
922
def update_td_op_definitions(path, instructions, docs, filter_list,
                             inst_category, capability_mapping, settings):
  """Updates SPIRVOps.td in place with newly generated op definitions.

  Ops already defined in the file are regenerated too, reusing their manually
  written sections (see extract_td_op_info). Ops that have no entry in the
  grammar are kept verbatim.

  Arguments:
    - path: path to SPIRVOps.td
    - instructions: SPIR-V JSON grammar for all instructions
    - docs: SPIR-V HTML doc for all instructions, keyed by opname
    - filter_list: a list containing new opnames to include; not modified.
    - inst_category: instruction category used for new ops that are not yet
                     defined in the file.
    - capability_mapping: mapping from duplicated capability symbols to the
                   canonicalized symbol chosen for SPIRVBase.td.
    - settings: parsed command-line settings; `settings.gen_cl_ops` selects
                the OpenCL extended instruction set naming.

  Raises:
    - ValueError: if the file contains no op separator, or if a requested op
      is neither in the grammar nor already defined in the file.
  """
  with open(path, 'r', encoding='utf-8') as f:
    content = f.read()

  # Split the file into chunks, each containing one op. The first chunk is the
  # file header and the last one is the footer.
  chunks = content.split(AUTOGEN_OP_DEF_SEPARATOR)
  if len(chunks) < 2:
    # Without any separator header and footer would be the same chunk and the
    # whole file would be written back twice.
    raise ValueError('cannot find op separator {!r} in {}'.format(
        AUTOGEN_OP_DEF_SEPARATOR, path))
  header, *ops, footer = chunks

  # For each existing op, extract the manually-written sections out to retain
  # them when re-generating the ops. Existing ops are also regenerated, so they
  # join the requested names. A new set is built so the caller's list is left
  # untouched.
  name_op_map = {}  # Map from opname to its existing ODS definition
  op_info_dict = {}  # Map from opname to its extracted manual sections
  opnames = set(filter_list)
  for op in ops:
    info_dict = extract_td_op_info(op)
    opname = info_dict['opname']
    name_op_map[opname] = op
    op_info_dict[opname] = info_dict
    opnames.add(opname)

  def fix_opname(src):
    """Maps an ODS opname to the opname used by the grammar and the docs."""
    if settings.gen_cl_ops:
      return src.replace('CL', '').lower()
    return src

  op_defs = []
  for opname in sorted(opnames):
    fixed_opname = fix_opname(opname)
    # Find the grammar spec for this op.
    instruction = next(
        (inst for inst in instructions if inst['opname'] == fixed_opname),
        None)
    if instruction is None:
      # This is an op added by us; use the existing ODS definition.
      if opname not in name_op_map:
        raise ValueError(
            'cannot find instruction "{}" in the SPIR-V grammar or in {}'
            .format(opname, path))
      op_defs.append(name_op_map[opname])
      continue

    op_defs.append(
        get_op_definition(
            instruction, opname, docs[fixed_opname],
            op_info_dict.get(opname, {'inst_category': inst_category}),
            capability_mapping, settings))

  # Substitute the old op definitions
  content = AUTOGEN_OP_DEF_SEPARATOR.join([header] + op_defs + [footer])

  with open(path, 'w', encoding='utf-8') as f:
    f.write(content)
989
990
if __name__ == '__main__':
  import argparse

  cli_parser = argparse.ArgumentParser(
      description='Update SPIR-V dialect definitions using SPIR-V spec')

  cli_parser.add_argument(
      '--base-td-path',
      dest='base_td_path',
      type=str,
      default=None,
      help='Path to SPIRVBase.td')
  cli_parser.add_argument(
      '--op-td-path',
      dest='op_td_path',
      type=str,
      default=None,
      help='Path to SPIRVOps.td')

  cli_parser.add_argument(
      '--new-enum',
      dest='new_enum',
      type=str,
      default=None,
      help='SPIR-V enum to be added to SPIRVBase.td')
  cli_parser.add_argument(
      '--new-opcodes',
      dest='new_opcodes',
      type=str,
      default=None,
      nargs='*',
      help='update SPIR-V opcodes in SPIRVBase.td')
  cli_parser.add_argument(
      '--new-inst',
      dest='new_inst',
      type=str,
      default=None,
      nargs='*',
      help='SPIR-V instruction to be added to ops file')
  cli_parser.add_argument(
      '--inst-category',
      dest='inst_category',
      type=str,
      default='Op',
      help='SPIR-V instruction category used for choosing '
           'the TableGen base class to define this op')
  # store_true options already default to False.
  cli_parser.add_argument(
      '--gen-cl-ops',
      dest='gen_cl_ops',
      help='Generate OpenCL Extended Instruction Set op',
      action='store_true')
  cli_parser.add_argument(
      '--gen-inst-coverage',
      dest='gen_inst_coverage',
      help='Generate SPIR-V instruction coverage report '
           '(requires --base-td-path)',
      action='store_true')

  args = cli_parser.parse_args()

  # Validate option combinations up front: before any network access and
  # before any file is rewritten. parser.error() exits with a usage message
  # and, unlike assert, is not stripped under `python -O`.
  needs_base_td = (args.new_enum is not None or
                   args.new_opcodes is not None or args.gen_inst_coverage)
  if needs_base_td and args.base_td_path is None:
    cli_parser.error('--base-td-path is required with --new-enum, '
                     '--new-opcodes and --gen-inst-coverage')
  if args.new_inst is not None and args.op_td_path is None:
    cli_parser.error('--op-td-path is required with --new-inst')

  if args.gen_cl_ops:
    ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL
    ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL
  else:
    ext_html_url = None
    ext_json_url = None

  operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url)

  # Define new enum attr
  if args.new_enum is not None:
    filter_list = [args.new_enum] if args.new_enum else []
    update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)

  # Define new opcode
  if args.new_opcodes is not None:
    update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)

  # Define new op
  if args.new_inst is not None:
    docs = get_spirv_doc_from_html_spec(ext_html_url, args)
    capability_mapping = get_capability_mapping(operand_kinds)
    update_td_op_definitions(args.op_td_path, instructions, docs, args.new_inst,
                             args.inst_category, capability_mapping, args)
    print('Done. Note that this script just generates a template; ', end='')
    print('please read the spec and update traits, arguments, and ', end='')
    print('results accordingly.')

  if args.gen_inst_coverage:
    gen_instr_coverage_report(args.base_td_path, instructions)
1081