1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5try:
6  from ..ir import *
7  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
8  from ..dialects import pdl
9except ImportError as e:
10  raise RuntimeError("Error loading imports from extension module") from e
11
12from typing import Optional, Union
13
14
def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]],
                    default_value: Optional[int] = None):
  """Coerces `arg` into a signless 64-bit `IntegerAttr`.

  Args:
    arg: An existing `IntegerAttr` (returned unchanged), a Python int, or
      None to fall back to `default_value`.
    default_value: Integer used when `arg` is None.

  Returns:
    `arg` itself if it is already an `IntegerAttr`; otherwise a new
    `IntegerAttr` of type signless i64 holding the integer value.

  Raises:
    ValueError: if `arg` is None and no `default_value` was provided.
  """
  if isinstance(arg, IntegerAttr):
    return arg

  if arg is None:
    # Use an explicit exception rather than `assert` so the check still runs
    # under `python -O`; otherwise None would reach IntegerAttr.get.
    if default_value is None:
      raise ValueError("must provide default value")
    arg = default_value

  return IntegerAttr.get(IntegerType.get_signless(64), arg)
25
26
class GetParentForOp:
  """Extension for GetParentForOp.

  The class has no explicit base; the `super().__init__` call is presumably
  resolved against the generated op class this extension is mixed into
  (TODO confirm against the binding generator).
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               num_loops: Optional[Union[int, IntegerAttr]] = 1,
               ip=None,
               loc=None):
    """Builds the op with a `pdl.OperationType` result.

    Args:
      target: Operation or value; normalized via `_get_op_result_or_value`.
      num_loops: Python int or `IntegerAttr`, passed on as a signless i64
        attribute. An explicit None also falls back to 1.
      ip: Forwarded unchanged to the base constructor.
      loc: Forwarded unchanged to the base constructor.
    """
    super().__init__(
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        num_loops=_get_int64_attr(num_loops, default_value=1),
        ip=ip,
        loc=loc)
42
43
class LoopOutlineOp:
  """Extension for LoopOutlineOp.

  Provides a constructor that accepts the outlined function name either as a
  plain string or as a `StringAttr`.
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               func_name: Union[str, StringAttr],
               ip=None,
               loc=None):
    """Builds the op with a `pdl.OperationType` result.

    Args:
      target: Operation or value; normalized via `_get_op_result_or_value`.
      func_name: Used as-is when already a `StringAttr`, otherwise wrapped
        with `StringAttr.get`.
      ip: Forwarded unchanged to the base constructor.
      loc: Forwarded unchanged to the base constructor.
    """
    result_type = pdl.OperationType.get()
    target_value = _get_op_result_or_value(target)
    if isinstance(func_name, StringAttr):
      name_attr = func_name
    else:
      name_attr = StringAttr.get(func_name)
    super().__init__(result_type,
                     target_value,
                     func_name=name_attr,
                     ip=ip,
                     loc=loc)
60
61
class LoopPeelOp:
  """Extension for LoopPeelOp.

  Provides a constructor that accepts the `fail_if_already_divisible` flag
  either as a Python bool or as a `BoolAttr`.
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               fail_if_already_divisible: Union[bool, BoolAttr] = False,
               ip=None,
               loc=None):
    """Builds the op with a `pdl.OperationType` result.

    Args:
      target: Operation or value; normalized via `_get_op_result_or_value`.
      fail_if_already_divisible: Used as-is when already a `BoolAttr`,
        otherwise wrapped with `BoolAttr.get`. Defaults to False.
      ip: Forwarded unchanged to the base constructor.
      loc: Forwarded unchanged to the base constructor.
    """
    result_type = pdl.OperationType.get()
    target_value = _get_op_result_or_value(target)
    if isinstance(fail_if_already_divisible, BoolAttr):
      divisible_attr = fail_if_already_divisible
    else:
      divisible_attr = BoolAttr.get(fail_if_already_divisible)
    super().__init__(result_type,
                     target_value,
                     fail_if_already_divisible=divisible_attr,
                     ip=ip,
                     loc=loc)
79
80
class LoopPipelineOp:
  """Extension for LoopPipelineOp.

  Provides a constructor whose integer attributes may be given as Python
  ints, `IntegerAttr`s, or omitted to use built-in defaults.
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               iteration_interval: Optional[Union[int, IntegerAttr]] = None,
               read_latency: Optional[Union[int, IntegerAttr]] = None,
               ip=None,
               loc=None):
    """Builds the op with a `pdl.OperationType` result.

    Args:
      target: Operation or value; normalized via `_get_op_result_or_value`.
      iteration_interval: Signless i64 attribute value; None selects 1.
      read_latency: Signless i64 attribute value; None selects 10.
      ip: Forwarded unchanged to the base constructor.
      loc: Forwarded unchanged to the base constructor.
    """
    result_type = pdl.OperationType.get()
    target_value = _get_op_result_or_value(target)
    interval_attr = _get_int64_attr(iteration_interval, default_value=1)
    latency_attr = _get_int64_attr(read_latency, default_value=10)
    super().__init__(result_type,
                     target_value,
                     iteration_interval=interval_attr,
                     read_latency=latency_attr,
                     ip=ip,
                     loc=loc)
98
99
class LoopUnrollOp:
  """Extension for LoopUnrollOp.

  No result type is passed to the base constructor; presumably the op
  produces no results (confirm against its ODS definition).
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               factor: Union[int, IntegerAttr],
               ip=None,
               loc=None):
    """Builds the op from `target` and an unroll `factor`.

    Args:
      target: Operation or value; normalized via `_get_op_result_or_value`.
      factor: Python int or `IntegerAttr`, passed on as a signless i64
        attribute. No default is supplied to `_get_int64_attr`, so an
        explicit None is rejected there.
      ip: Forwarded unchanged to the base constructor.
      loc: Forwarded unchanged to the base constructor.
    """
    target_value = _get_op_result_or_value(target)
    factor_attr = _get_int64_attr(factor)
    super().__init__(target_value, factor=factor_attr, ip=ip, loc=loc)
114