1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5try:
6  from ..ir import *
7  from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
8  from ..dialects import pdl
9except ImportError as e:
10  raise RuntimeError("Error loading imports from extension module") from e
11
12from typing import Optional, overload, Sequence, Union
13
14
def _get_symbol_ref_attr(value: Union[Attribute, str]):
  """Coerce `value` into an attribute usable as a symbol reference.

  An `Attribute` is returned unchanged. Anything else (expected to be a
  `str` symbol name) is wrapped via `FlatSymbolRefAttr.get`.
  """
  if not isinstance(value, Attribute):
    value = FlatSymbolRefAttr.get(value)
  return value
19
20
class GetClosestIsolatedParentOp:
  """Convenience constructor for the get-closest-isolated-parent op.

  NOTE(review): this class declares no base; the `super().__init__` call
  presumably reaches a generated op class this is mixed into — confirm.
  """

  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
    """Build the op on `target` with a single `pdl.OperationType` result.

    Args:
      target: operation or value to start from; normalized with
        `_get_op_result_or_value`.
      loc: optional location forwarded to the base constructor.
      ip: optional insertion point forwarded to the base constructor.
    """
    result_type = pdl.OperationType.get()
    target_value = _get_op_result_or_value(target)
    super().__init__(result_type, target_value, loc=loc, ip=ip)
29
30
class MergeHandlesOp:
  """Convenience constructor for the merge-handles op.

  NOTE(review): relies on a generated op base class providing
  `super().__init__` — confirm against the extension loader.
  """

  def __init__(self,
               handles: Sequence[Union[Operation, Value]],
               *,
               deduplicate: bool = False,
               loc=None,
               ip=None):
    """Build the op merging `handles` into one `pdl.OperationType` result.

    Args:
      handles: operations or values; each is normalized with
        `_get_op_result_or_value` before being passed as an operand.
      deduplicate: forwarded unchanged as the `deduplicate` keyword.
      loc: optional location forwarded to the base constructor.
      ip: optional insertion point forwarded to the base constructor.
    """
    result_type = pdl.OperationType.get()
    handle_values = list(map(_get_op_result_or_value, handles))
    super().__init__(
        result_type,
        handle_values,
        deduplicate=deduplicate,
        loc=loc,
        ip=ip)
44
45
class PDLMatchOp:
  """Convenience constructor for the PDL-match transform op.

  NOTE(review): relies on a generated op base class providing
  `super().__init__` — confirm against the extension loader.
  """

  def __init__(self,
               target: Union[Operation, Value],
               pattern_name: Union[Attribute, str],
               *,
               loc=None,
               ip=None):
    """Build the op matching `pattern_name` against `target`.

    Args:
      target: operation or value; normalized with `_get_op_result_or_value`.
      pattern_name: an attribute (used as-is) or a symbol name string
        (wrapped in `FlatSymbolRefAttr`).
      loc: optional location forwarded to the base constructor.
      ip: optional insertion point forwarded to the base constructor.
    """
    result_type = pdl.OperationType.get()
    target_value = _get_op_result_or_value(target)
    # Inline form of `_get_symbol_ref_attr`: attributes pass through,
    # anything else becomes a flat symbol reference.
    if isinstance(pattern_name, Attribute):
      pattern_attr = pattern_name
    else:
      pattern_attr = FlatSymbolRefAttr.get(pattern_name)
    super().__init__(result_type, target_value, pattern_attr, loc=loc, ip=ip)
60
61
class SequenceOp:
  """Convenience constructor for the transform sequence op.

  Can be called as `SequenceOp(result_types, root)` or `SequenceOp(root)`.
  The constructor also appends the op's body block, which has a single
  `pdl.OperationType` argument.
  """

  @overload
  def __init__(self,
               resultsOrRoot: Sequence[Type],
               optionalRoot: Optional[Union[Operation, Value]] = None,
               *,
               loc=None,
               ip=None):
    ...

  # `None` in an annotation denotes Python's None type (PEP 484); the root is
  # then taken from the first argument.
  @overload
  def __init__(self,
               resultsOrRoot: Optional[Union[Operation, Value]] = None,
               optionalRoot: None = None,
               *,
               loc=None,
               ip=None):
    ...

  def __init__(self,
               resultsOrRoot=None,
               optionalRoot=None,
               *,
               loc=None,
               ip=None):
    """Build the sequence op and its body block.

    Args:
      resultsOrRoot: either a sequence of result types, or the root
        operation/value (or None) when no result types are given.
      optionalRoot: the root operation/value when `resultsOrRoot` is a
        sequence of result types; ignored otherwise.
      loc: optional location forwarded to the base constructor.
      ip: optional insertion point forwarded to the base constructor.
    """
    if isinstance(resultsOrRoot, Sequence):
      results = resultsOrRoot
      root = optionalRoot
    else:
      results = []
      root = resultsOrRoot
    # Compare against None explicitly rather than relying on the truthiness
    # of binding objects.
    if root is not None:
      root = _get_op_result_or_value(root)
    super().__init__(results_=results, root=root, loc=loc, ip=ip)
    self.regions[0].blocks.append(pdl.OperationType.get())

  @property
  def body(self) -> Block:
    """The first block of region 0, created by the constructor."""
    return self.regions[0].blocks[0]

  @property
  def bodyTarget(self) -> Value:
    """The first (and only) argument of the body block."""
    return self.body.arguments[0]
90
91
class WithPDLPatternsOp:
  """Convenience constructor for the with-PDL-patterns transform op.

  The constructor also appends the op's body block, which has a single
  `pdl.OperationType` argument.
  """

  def __init__(self,
               target: Optional[Union[Operation, Value]] = None,
               *,
               loc=None,
               ip=None):
    """Build the op and its body block.

    Args:
      target: optional root operation/value; normalized with
        `_get_op_result_or_value` when given, otherwise no root is passed.
      loc: optional location forwarded to the base constructor.
      ip: optional insertion point forwarded to the base constructor.
    """
    # Compare against None explicitly rather than relying on the truthiness
    # of binding objects.
    root = _get_op_result_or_value(target) if target is not None else None
    super().__init__(root=root, loc=loc, ip=ip)
    self.regions[0].blocks.append(pdl.OperationType.get())

  @property
  def body(self) -> Block:
    """The first block of region 0, created by the constructor."""
    return self.regions[0].blocks[0]

  @property
  def bodyTarget(self) -> Value:
    """The first (and only) argument of the body block."""
    return self.body.arguments[0]
112
113
class YieldOp:
  """Convenience constructor for the transform yield op."""

  def __init__(self,
               operands: Optional[Union[Operation, Sequence[Value]]] = None,
               *,
               loc=None,
               ip=None):
    """Build the yield op.

    Args:
      operands: an operation (whose results are yielded) or a sequence of
        values, normalized with `_get_op_results_or_values`. `None` (the
        default) yields nothing.
      loc: optional location forwarded to the base constructor.
      ip: optional insertion point forwarded to the base constructor.
    """
    # Use a None sentinel instead of a shared mutable `[]` default.
    if operands is None:
      operands = []
    super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
122