1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5try:
6  from ..ir import *
7  from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
8  from ..dialects import pdl
9except ImportError as e:
10  raise RuntimeError("Error loading imports from extension module") from e
11
12from typing import Optional, overload, Sequence, Union
13
14
def _get_symbol_ref_attr(value: Union[Attribute, str]):
  """Coerce `value` into a symbol reference attribute.

  An `Attribute` instance is passed through untouched; any other value (a
  symbol name string) is wrapped via `FlatSymbolRefAttr.get`.
  """
  return value if isinstance(value, Attribute) else FlatSymbolRefAttr.get(value)
19
20
class GetClosestIsolatedParentOp:
  """Extension constructor for GetClosestIsolatedParentOp.

  Accepts the target either as an operation or as a value, converts it to a
  value handle and forwards it, together with a PDL operation result type
  (`pdl.OperationType`), to the next `__init__` in the MRO (presumably the
  generated op builder -- not visible here).
  """

  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
    result_type = pdl.OperationType.get()
    target_handle = _get_op_result_or_value(target)
    super().__init__(result_type, target_handle, loc=loc, ip=ip)
29
30
class MergeHandlesOp:
  """Extension constructor for MergeHandlesOp.

  Args:
    handles: operations or values to merge; each is converted to a value
      handle before being forwarded.
    deduplicate: forwarded unchanged as the `deduplicate` keyword.
    loc: optional location forwarded to the base constructor.
    ip: optional insertion point forwarded to the base constructor.
  """

  def __init__(self,
               handles: Sequence[Union[Operation, Value]],
               *,
               deduplicate: bool = False,
               loc=None,
               ip=None):
    result_type = pdl.OperationType.get()
    handle_values = list(map(_get_op_result_or_value, handles))
    super().__init__(result_type,
                     handle_values,
                     deduplicate=deduplicate,
                     loc=loc,
                     ip=ip)
44
45
class PDLMatchOp:
  """Extension constructor for PDLMatchOp.

  Args:
    target: operation or value to match against; converted to a value handle.
    pattern_name: either a ready-made attribute or a symbol name string, which
      is wrapped into a `FlatSymbolRefAttr`.
    loc: optional location forwarded to the base constructor.
    ip: optional insertion point forwarded to the base constructor.
  """

  def __init__(self,
               target: Union[Operation, Value],
               pattern_name: Union[Attribute, str],
               *,
               loc=None,
               ip=None):
    result_type = pdl.OperationType.get()
    target_handle = _get_op_result_or_value(target)
    pattern_ref = _get_symbol_ref_attr(pattern_name)
    super().__init__(result_type, target_handle, pattern_ref, loc=loc, ip=ip)
60
61
class ReplicateOp:
  """Extension constructor for ReplicateOp.

  Creates one PDL operation-typed result per input handle.

  Args:
    pattern: operation or value used as the pattern handle; converted to a
      value handle.
    handles: operations or values to replicate. Any iterable is accepted (it
      is consumed exactly once), so generators work as well as sequences.
    loc: optional location forwarded to the base constructor.
    ip: optional insertion point forwarded to the base constructor.
  """

  def __init__(self,
               pattern: Union[Operation, Value],
               handles: Sequence[Union[Operation, Value]],
               *,
               loc=None,
               ip=None):
    # Materialize the handles once so that counting them (for the result
    # types) and forwarding them agree even for one-shot iterables, on which
    # len() would fail and a second iteration would see nothing.
    handle_values = [_get_op_result_or_value(h) for h in handles]
    super().__init__(
        [pdl.OperationType.get()] * len(handle_values),
        _get_op_result_or_value(pattern),
        handle_values,
        loc=loc,
        ip=ip)
76
77
class SequenceOp:
  """Extension constructor for SequenceOp with an overloaded argument list.

  Two call forms are supported:
    * ``SequenceOp(results, root)`` -- ``results`` is a sequence of result
      types and ``root`` an optional operation/value handle;
    * ``SequenceOp(root)`` -- only the optional root handle is given and the
      op gets no result types.
  After the base constructor runs, one block whose single argument has a PDL
  operation type is appended to the op's first region.
  """

  @overload
  def __init__(self,
               resultsOrRoot: Sequence[Type],
               optionalRoot: Optional[Union[Operation, Value]],
               *,
               loc=None,
               ip=None):
    ...

  # `None` here is Python's None; the previous `NoneType` annotation resolved
  # to MLIR's `none` type class imported from `..ir`.
  @overload
  def __init__(self,
               resultsOrRoot: Optional[Union[Operation, Value]] = None,
               optionalRoot: None = None,
               *,
               loc=None,
               ip=None):
    ...

  def __init__(self,
               resultsOrRoot=None,
               optionalRoot=None,
               *,
               loc=None,
               ip=None):
    # A sequence in the first position means "result types"; anything else
    # (including None) is the root handle itself.
    if isinstance(resultsOrRoot, Sequence):
      results = resultsOrRoot
      root = optionalRoot
    else:
      results = []
      root = resultsOrRoot
    # Test against None explicitly instead of relying on the truthiness of an
    # Operation/Value object.
    if root is not None:
      root = _get_op_result_or_value(root)
    super().__init__(results_=results, root=root, loc=loc, ip=ip)
    self.regions[0].blocks.append(pdl.OperationType.get())

  @property
  def body(self) -> Block:
    """The first block of the op's first region (appended in __init__)."""
    return self.regions[0].blocks[0]

  @property
  def bodyTarget(self) -> Value:
    """The first argument of `body`."""
    return self.body.arguments[0]
106
107
class WithPDLPatternsOp:
  """Extension constructor for WithPDLPatternsOp.

  Args:
    target: optional operation or value used as the `root` operand; when
      given it is converted to a value handle, otherwise `root=None` is
      forwarded.
    loc: optional location forwarded to the base constructor.
    ip: optional insertion point forwarded to the base constructor.

  After the base constructor runs, one block whose single argument has a PDL
  operation type is appended to the op's first region.
  """

  def __init__(self,
               target: Optional[Union[Operation, Value]] = None,
               *,
               loc=None,
               ip=None):
    # Explicit None test rather than truthiness of an Operation/Value object.
    root = _get_op_result_or_value(target) if target is not None else None
    super().__init__(root=root, loc=loc, ip=ip)
    self.regions[0].blocks.append(pdl.OperationType.get())

  @property
  def body(self) -> Block:
    """The first block of the op's first region (appended in __init__)."""
    return self.regions[0].blocks[0]

  @property
  def bodyTarget(self) -> Value:
    """The first argument of `body`."""
    return self.body.arguments[0]
128
129
class YieldOp:
  """Extension constructor for YieldOp.

  Args:
    operands: either an operation (whose results are yielded) or a sequence
      of values; `None` (the default) yields nothing.
    loc: optional location forwarded to the base constructor.
    ip: optional insertion point forwarded to the base constructor.
  """

  def __init__(self,
               operands: Optional[Union[Operation, Sequence[Value]]] = None,
               *,
               loc=None,
               ip=None):
    # Use a None sentinel instead of a shared mutable `[]` default, which
    # would be the same list object across every call.
    if operands is None:
      operands = []
    super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
138