1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5try:
6  from ..ir import *
7  from ._ods_common import get_default_loc_context as _get_default_loc_context
8
9  import inspect
10
11  from typing import Any, List, Optional, Sequence, Union
12except ImportError as e:
13  raise RuntimeError("Error loading imports from extension module") from e
14
# Attribute names under which a function op stores its argument attributes
# and its result attributes, respectively.
ARGUMENT_ATTRIBUTE_NAME, RESULT_ATTRIBUTE_NAME = "arg_attrs", "res_attrs"
17
class ConstantOp:
  """Mixin specializing the constant op class.

  Adds a `(result, value)` constructor and a `type` accessor. This class has
  no base of its own, so `super().__init__` and `results` are expected to be
  supplied by the op class it is combined with.
  """

  def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
    """Forward the result type and value attribute to the underlying builder.

    `loc` and `ip` are passed through unchanged as keyword arguments.
    """
    forwarded = dict(loc=loc, ip=ip)
    super().__init__(result, value, **forwarded)

  @property
  def type(self):
    """Type of the op's first result."""
    head = self.results[0]
    return head.type
27
28
class FuncOp:
  """Specialization for the func op class.

  This is a mixin with no base class of its own: `super().__init__`,
  `attributes`, `regions` and `context` are expected to be supplied by the op
  class it is combined with.
  """

  def __init__(self,
               name,
               type,
               *,
               visibility=None,
               body_builder=None,
               loc=None,
               ip=None):
    """
    Create a FuncOp with the provided `name`, `type`, and `visibility`.
    - `name` is a string representing the function name.
    - `type` is either a FunctionType or a pair of lists describing inputs and
      results.
    - `visibility` is a string matching `public`, `private`, or `nested`. None
      leaves the `sym_visibility` attribute unset, which MLIR treats as public
      visibility.
    - `body_builder` is an optional callback, when provided a new entry block
      is created and the callback is invoked with the new op as argument within
      an InsertionPoint context already set for the block. The callback is
      expected to insert a terminator in the block.
    """
    sym_name = StringAttr.get(str(name))

    # If the type is passed as a tuple, build a FunctionType on the fly.
    if isinstance(type, tuple):
      type = FunctionType.get(inputs=type[0], results=type[1])

    type = TypeAttr.get(type)
    sym_visibility = StringAttr.get(
        str(visibility)) if visibility is not None else None
    super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
    if body_builder:
      entry_block = self.add_entry_block()
      with InsertionPoint(entry_block):
        body_builder(self)

  @property
  def is_external(self):
    """True when the body region has no blocks (declaration only)."""
    return len(self.regions[0].blocks) == 0

  @property
  def body(self):
    """The function's body region."""
    return self.regions[0]

  @property
  def type(self):
    """The FunctionType stored in the `function_type` attribute."""
    return FunctionType(TypeAttr(self.attributes["function_type"]).value)

  @property
  def visibility(self):
    """The raw `sym_visibility` attribute.

    The attribute is looked up directly, so this fails (presumably with a
    KeyError from the attribute map) when no visibility was set.
    """
    return self.attributes["sym_visibility"]

  @property
  def name(self) -> StringAttr:
    """The function's symbol name as a StringAttr."""
    return StringAttr(self.attributes["sym_name"])

  @property
  def entry_block(self):
    """The first block of the body.

    Raises:
      IndexError: if the function is external (has no body).
    """
    if self.is_external:
      raise IndexError('External function does not have a body')
    return self.regions[0].blocks[0]

  def add_entry_block(self):
    """
    Add an entry block to the function body using the function signature to
    infer block arguments.
    Returns the newly created block.

    Raises:
      IndexError: if the function already has an entry block.
    """
    if not self.is_external:
      raise IndexError('The function already has an entry block!')
    self.body.blocks.append(*self.type.inputs)
    return self.body.blocks[0]

  def _set_array_attribute(self, attr_name, attribute):
    """Store `attribute` under `attr_name`, wrapping plain lists in ArrayAttr."""
    if isinstance(attribute, ArrayAttr):
      self.attributes[attr_name] = attribute
    else:
      self.attributes[attr_name] = ArrayAttr.get(
          attribute, context=self.context)

  @property
  def arg_attrs(self):
    """The argument attributes as an ArrayAttr."""
    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

  @arg_attrs.setter
  def arg_attrs(self, attribute: Union[ArrayAttr, list]):
    self._set_array_attribute(ARGUMENT_ATTRIBUTE_NAME, attribute)

  @property
  def arguments(self):
    """Block arguments of the entry block."""
    return self.entry_block.arguments

  @property
  def result_attrs(self):
    """The result attributes as an ArrayAttr (consistent with `arg_attrs`)."""
    return ArrayAttr(self.attributes[RESULT_ATTRIBUTE_NAME])

  @result_attrs.setter
  def result_attrs(self, attribute: Union[ArrayAttr, list]):
    # Accepts either an ArrayAttr or a plain list, mirroring `arg_attrs`.
    self._set_array_attribute(RESULT_ATTRIBUTE_NAME, attribute)

  @classmethod
  def from_py_func(cls,
                   *inputs: Type,
                   results: Optional[Sequence[Type]] = None,
                   name: Optional[str] = None):
    """Decorator to define an MLIR FuncOp specified as a python function.

    Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
    active for the current thread (i.e. established in a `with` block).

    When applied as a decorator to a Python function, an entry block will
    be constructed for the FuncOp with types as specified in `*inputs`. The
    block arguments will be passed positionally to the Python function. In
    addition, if the Python function accepts keyword arguments generally or
    has a corresponding keyword argument, the following will be passed:
      * `func_op`: The `func` op being defined.

    By default, the function name will be the Python function `__name__`. This
    can be overridden by passing the `name` argument to the decorator.

    If `results` is not specified, then the decorator will implicitly
    insert a `ReturnOp` with the `Value`'s returned from the decorated
    function. It will also set the `FuncOp` type with the actual return
    value types. If `results` is specified, then the decorated function
    must return `None` and no implicit `ReturnOp` is added (nor are the result
    types updated). The implicit behavior is intended for simple, single-block
    cases, and users should specify result types explicitly for any complicated
    cases.

    The decorated function can further be called from Python and will insert
    a `CallOp` at the then-current insertion point, returning either None (
    if no return values), a unary Value (for one result), or a list of Values).
    This mechanism cannot be used to emit recursive calls (by construction).
    """

    def decorator(f):
      from . import func
      # Introspect the callable: `func_op` is only passed if `f` can accept
      # it, either via **kwargs or an explicitly named keyword parameter.
      sig = inspect.signature(f)
      has_arg_func_op = False
      for param in sig.parameters.values():
        if param.kind == param.VAR_KEYWORD:
          has_arg_func_op = True
        if param.name == "func_op" and param.kind in (
            param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY):
          has_arg_func_op = True

      # Emit the FuncOp. With an implicit return, the result types are not
      # known yet and are patched in after the body has been built.
      implicit_return = results is None
      symbol_name = name or f.__name__
      function_type = FunctionType.get(
          inputs=inputs, results=[] if implicit_return else results)
      func_op = cls(name=symbol_name, type=function_type)
      with InsertionPoint(func_op.add_entry_block()):
        func_args = func_op.entry_block.arguments
        func_kwargs = {}
        if has_arg_func_op:
          func_kwargs["func_op"] = func_op
        return_values = f(*func_args, **func_kwargs)
        if not implicit_return:
          return_types = list(results)
          assert return_values is None, (
              "Capturing a python function with explicit `results=` "
              "requires that the wrapped function returns None.")
        else:
          # Coerce return values, add ReturnOp and rewrite func type.
          if return_values is None:
            return_values = []
          elif isinstance(return_values, tuple):
            return_values = list(return_values)
          elif isinstance(return_values, Value):
            # Returning a single value is fine, coerce it into a list.
            return_values = [return_values]
          elif isinstance(return_values, OpView):
            # Returning a single operation is fine, coerce its results a list.
            return_values = return_values.operation.results
          elif isinstance(return_values, Operation):
            # Returning a single operation is fine, coerce its results a list.
            return_values = return_values.results
          else:
            return_values = list(return_values)
          func.ReturnOp(return_values)
          # Recompute the function type.
          return_types = [v.type for v in return_values]
          function_type = FunctionType.get(inputs=inputs, results=return_types)
          func_op.attributes["function_type"] = TypeAttr.get(function_type)

      def emit_call_op(*call_args):
        call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
                              call_args)
        # `return_types` is always a list at this point, so test for
        # emptiness: the documented contract is to return None when the
        # function has no results.
        if not return_types:
          return None
        if len(return_types) == 1:
          return call_op.result
        return call_op.results

      wrapped = emit_call_op
      wrapped.__name__ = f.__name__
      wrapped.func_op = func_op
      return wrapped

    return decorator
232
class CallOp:
  """Specialization for the call op class."""

  def __init__(self,
               calleeOrResults: Union[FuncOp, List[Type]],
               argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
               arguments: Optional[List] = None,
               *,
               loc=None,
               ip=None):
    """Creates a call operation.

    The constructor accepts three different forms:

      1. A function op to be called followed by a list of arguments.
      2. A list of result types, followed by the name of the function to be
         called as string, followed by a list of arguments.
      3. A list of result types, followed by the name of the function to be
         called as symbol reference attribute, followed by a list of arguments.

    For example

        f = func.FuncOp("foo", ...)
        func.CallOp(f, [args])
        func.CallOp([result_types], "foo", [args])

    In all cases, the location and insertion point may be specified as keyword
    arguments if not provided by the surrounding context managers.

    Raises:
      ValueError: if the positional arguments match none of the forms above.
    """

    # TODO: consider supporting constructor "overloads", e.g., through a custom
    # or pybind-provided metaclass.
    if isinstance(calleeOrResults, FuncOp):
      # Form 1: result types and callee symbol are derived from the FuncOp.
      if not isinstance(argumentsOrCallee, list):
        raise ValueError(
            "when constructing a call to a function, expected " +
            "the second argument to be a list of call arguments, " +
            f"got {type(argumentsOrCallee)}")
      if arguments is not None:
        raise ValueError("unexpected third argument when constructing a call " +
                         "to a function")

      super().__init__(
          calleeOrResults.type.results,
          FlatSymbolRefAttr.get(
              calleeOrResults.name.value,
              context=_get_default_loc_context(loc)),
          argumentsOrCallee,
          loc=loc,
          ip=ip)
      return

    # Forms 2 and 3: normalize the callee to a FlatSymbolRefAttr.
    if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
      callee = argumentsOrCallee
    elif isinstance(argumentsOrCallee, str):
      callee = FlatSymbolRefAttr.get(
          argumentsOrCallee, context=_get_default_loc_context(loc))
    else:
      # Reject anything else (a list here suggests form 1 was intended)
      # instead of returning with the op left uninitialized.
      raise ValueError("when constructing a call to a function by name, " +
                       "expected the second argument to be a string or a " +
                       f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")

    super().__init__(calleeOrResults, callee, arguments, loc=loc, ip=ip)
301