1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5try:
6  from ..ir import *
7  from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
8  from ..dialects import pdl
9except ImportError as e:
10  raise RuntimeError("Error loading imports from extension module") from e
11
12from typing import Optional, overload, Sequence, Union
13
14
def _get_symbol_ref_attr(value: Union[Attribute, str]):
  """Coerce ``value`` into a symbol reference attribute.

  An ``Attribute`` instance is returned unchanged; any other value (expected
  to be a symbol name string) is wrapped with ``FlatSymbolRefAttr.get``.
  """
  if not isinstance(value, Attribute):
    return FlatSymbolRefAttr.get(value)
  return value
19
20
class GetClosestIsolatedParentOp:
  """Convenience builder extension for the transform GetClosestIsolatedParentOp."""

  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
    """Build the op from a target handle.

    Args:
      target: Operation or Value used as the operand; it is passed through
        `_get_op_result_or_value` before being handed to the generated builder.
      loc: forwarded unchanged to the generated builder.
      ip: forwarded unchanged to the generated builder.
    """
    # Evaluation order matches the generated builder's positional order:
    # result type first, then the target operand.
    result_type = pdl.OperationType.get()
    target_value = _get_op_result_or_value(target)
    super().__init__(result_type, target_value, loc=loc, ip=ip)
29
30
class PDLMatchOp:
  """Convenience builder extension for the transform PDLMatchOp."""

  def __init__(self,
               target: Union[Operation, Value],
               pattern_name: Union[Attribute, str],
               *,
               loc=None,
               ip=None):
    """Build the op from a target handle and a pattern symbol.

    Args:
      target: Operation or Value used as the operand; passed through
        `_get_op_result_or_value`.
      pattern_name: the pattern symbol, either as an Attribute (used as-is)
        or as a plain name string (wrapped in a FlatSymbolRefAttr).
      loc: forwarded unchanged to the generated builder.
      ip: forwarded unchanged to the generated builder.
    """
    result_type = pdl.OperationType.get()
    target_value = _get_op_result_or_value(target)
    pattern_ref = _get_symbol_ref_attr(pattern_name)
    super().__init__(result_type, target_value, pattern_ref, loc=loc, ip=ip)
45
46
class SequenceOp:
  """Builder extension for the transform SequenceOp.

  The builder accepts either ``(result_types, root)`` or just ``(root)`` and,
  after constructing the op, appends the entry block of its first region with
  a single ``pdl.OperationType`` argument (exposed as ``bodyTarget``).
  """

  @overload
  def __init__(self,
               resultsOrRoot: Sequence[Type],
               optionalRoot: Optional[Union[Operation, Value]] = None,
               *,
               loc=None,
               ip=None):
    ...

  # `None` (not the undefined name `NoneType`) is the PEP 484 spelling of the
  # None type; an undefined name here raised NameError at class definition.
  @overload
  def __init__(self,
               resultsOrRoot: Optional[Union[Operation, Value]] = None,
               optionalRoot: None = None,
               *,
               loc=None,
               ip=None):
    ...

  def __init__(self,
               resultsOrRoot=None,
               optionalRoot=None,
               *,
               loc=None,
               ip=None):
    """Build the op.

    Args:
      resultsOrRoot: either a sequence of result types (then the root, if
        any, is ``optionalRoot``) or the root itself (then there are no
        results).
      optionalRoot: the root when ``resultsOrRoot`` is a sequence of types;
        ignored otherwise.
      loc: forwarded unchanged to the generated builder.
      ip: forwarded unchanged to the generated builder.
    """
    # Disambiguate the first positional argument once.
    if isinstance(resultsOrRoot, Sequence):
      results = resultsOrRoot
      root = optionalRoot
    else:
      results = []
      root = resultsOrRoot
    # Only an absent root maps to None; don't rely on IR handle truthiness.
    if root is not None:
      root = _get_op_result_or_value(root)
    super().__init__(results_=results, root=root, loc=loc, ip=ip)
    # Entry block whose single argument type is `pdl.OperationType`.
    self.regions[0].blocks.append(pdl.OperationType.get())

  @property
  def body(self) -> Block:
    """The first block of the op's first region."""
    return self.regions[0].blocks[0]

  @property
  def bodyTarget(self) -> Value:
    """The first argument of `body`."""
    return self.body.arguments[0]
75
76
class WithPDLPatternsOp:
  """Builder extension for the transform WithPDLPatternsOp.

  After constructing the op, appends the entry block of its first region with
  a single ``pdl.OperationType`` argument (exposed as ``bodyTarget``).
  """

  def __init__(self,
               target: Optional[Union[Operation, Value]] = None,
               *,
               loc=None,
               ip=None):
    """Build the op.

    Args:
      target: optional root; when given it is passed through
        `_get_op_result_or_value`, otherwise the op is built with no root.
      loc: forwarded unchanged to the generated builder.
      ip: forwarded unchanged to the generated builder.
    """
    # Only an absent target maps to None; don't rely on IR handle truthiness.
    root = _get_op_result_or_value(target) if target is not None else None
    super().__init__(root=root, loc=loc, ip=ip)
    # Entry block whose single argument type is `pdl.OperationType`.
    self.regions[0].blocks.append(pdl.OperationType.get())

  @property
  def body(self) -> Block:
    """The first block of the op's first region."""
    return self.regions[0].blocks[0]

  @property
  def bodyTarget(self) -> Value:
    """The first argument of `body`."""
    return self.body.arguments[0]
97
98
class YieldOp:
  """Builder extension for the transform YieldOp."""

  def __init__(self,
               operands: Optional[Union[Operation, Sequence[Value]]] = None,
               *,
               loc=None,
               ip=None):
    """Build the op from the values to yield.

    Args:
      operands: an Operation or a sequence of Values, passed through
        `_get_op_results_or_values`. Omitted or None yields nothing.
      loc: forwarded unchanged to the generated builder.
      ip: forwarded unchanged to the generated builder.
    """
    # None sentinel instead of a mutable `[]` default shared across calls.
    if operands is None:
      operands = []
    super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
107