1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5try:
6  from typing import Union
7  from ..ir import *
8  from ._ods_common import get_default_loc_context as _get_default_loc_context
9except ImportError as e:
10  raise RuntimeError("Error loading imports from extension module") from e
11
12from ._ml_program_ops_gen import *
13
14
# Attribute names under which FuncOp stores its per-argument and per-result
# attribute arrays (read/written by FuncOp.arg_attrs and FuncOp.result_attrs).
ARGUMENT_ATTRIBUTE_NAME: str = "arg_attrs"
RESULT_ATTRIBUTE_NAME: str = "res_attrs"
17
18
class FuncOp:
  """Specialization for the func op class.

  This class declares no base class of its own, yet `__init__` calls
  `super().__init__`. It is presumably mixed into the generated op class from
  `_ml_program_ops_gen`, which supplies `attributes`, `regions` and `context`
  (NOTE(review): confirm against the ODS extension mechanism).
  """

  def __init__(self,
               name,
               type,
               *,
               visibility=None,
               body_builder=None,
               loc=None,
               ip=None):
    """
    Create a FuncOp with the provided `name`, `type`, and `visibility`.
    - `name` is converted with `str()` and stored as the `sym_name` StringAttr.
    - `type` is either a FunctionType or a tuple `(inputs, results)` of type
      lists, from which a FunctionType is built on the fly.
    - `visibility` is a string such as `public`, `private`, or `nested`. When
      None, no `sym_visibility` attribute is set; MLIR treats a symbol without
      an explicit visibility as public.
    - `body_builder` is an optional callback. When provided, a new entry block
      is created and the callback is invoked with the new op as argument within
      an InsertionPoint context already set for the block. The callback is
      expected to insert a terminator in the block.
    - `loc` and `ip` are forwarded unchanged to the generated op constructor.
    """
    sym_name = StringAttr.get(str(name))

    # A tuple is interpreted as (inputs, results); anything else is expected
    # to already be a FunctionType.
    function_type = type
    if isinstance(function_type, tuple):
      function_type = FunctionType.get(
          inputs=function_type[0], results=function_type[1])

    type_attr = TypeAttr.get(function_type)
    sym_visibility = (StringAttr.get(str(visibility))
                      if visibility is not None else None)
    super().__init__(
        sym_name, type_attr, sym_visibility=sym_visibility, loc=loc, ip=ip)
    if body_builder:
      entry_block = self.add_entry_block()
      with InsertionPoint(entry_block):
        body_builder(self)

  @property
  def is_external(self):
    """True when the body region has no blocks (a declaration only)."""
    return len(self.regions[0].blocks) == 0

  @property
  def body(self):
    """The function's body region (the op's first region)."""
    return self.regions[0]

  @property
  def type(self):
    """The signature, read from the `function_type` attribute."""
    return FunctionType(TypeAttr(self.attributes["function_type"]).value)

  @property
  def visibility(self):
    """The raw `sym_visibility` attribute.

    The lookup fails (KeyError, per the attribute map) when the op has no
    explicit visibility.
    """
    return self.attributes["sym_visibility"]

  @property
  def name(self) -> StringAttr:
    """The symbol name, as a StringAttr."""
    return StringAttr(self.attributes["sym_name"])

  @property
  def entry_block(self):
    """The first block of the body.

    Raises IndexError if the function is external (has no body).
    """
    if self.is_external:
      raise IndexError('External function does not have a body')
    return self.regions[0].blocks[0]

  def add_entry_block(self):
    """
    Add an entry block to the function body using the function signature to
    infer block arguments.
    Returns the newly created block.
    Raises IndexError if the function already has an entry block.
    """
    if not self.is_external:
      raise IndexError('The function already has an entry block!')
    # One block argument per input type of the signature.
    self.body.blocks.append(*self.type.inputs)
    return self.body.blocks[0]

  @property
  def arg_attrs(self):
    """The per-argument attribute array stored under `arg_attrs`."""
    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

  @arg_attrs.setter
  def arg_attrs(self, attribute: Union[ArrayAttr, list]):
    """Set `arg_attrs` from an ArrayAttr or a list of attributes."""
    if isinstance(attribute, ArrayAttr):
      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
    else:
      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
          attribute, context=self.context)

  @property
  def arguments(self):
    """Block arguments of the entry block (IndexError if external)."""
    return self.entry_block.arguments

  @property
  def result_attrs(self):
    """The per-result attribute array stored under `res_attrs`.

    Wrapped in ArrayAttr for consistency with `arg_attrs`.
    """
    return ArrayAttr(self.attributes[RESULT_ATTRIBUTE_NAME])

  @result_attrs.setter
  def result_attrs(self, attribute: Union[ArrayAttr, list]):
    """Set `res_attrs` from an ArrayAttr or a list of attributes.

    Accepting a plain list mirrors the `arg_attrs` setter; an ArrayAttr is
    stored unchanged, as before.
    """
    if isinstance(attribute, ArrayAttr):
      self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
    else:
      self.attributes[RESULT_ATTRIBUTE_NAME] = ArrayAttr.get(
          attribute, context=self.context)
117