1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5try:
6  from ..ir import *
7  from ..dialects import pdl
8except ImportError as e:
9  raise RuntimeError("Error loading imports from extension module") from e
10
11from typing import Union, Optional, Sequence, List, Mapping
12from ._ods_common import get_op_result_or_value as _get_value, get_op_results_or_values as _get_values
13
14
def _get_int_attr(bits: int, value: Union[IntegerAttr, int]) -> IntegerAttr:
  """Returns `value` as an IntegerAttr.

  A plain Python int is wrapped in a signless integer attribute `bits` wide;
  any other input is returned unchanged.
  """
  if not isinstance(value, int):
    return value
  return IntegerAttr.get(IntegerType.get_signless(bits), value)
22
23
def _get_array_attr(attrs: Union[ArrayAttr, Sequence[Attribute]]) -> ArrayAttr:
  """Returns `attrs` as an ArrayAttr, building one from a sequence if needed."""
  if not isinstance(attrs, ArrayAttr):
    attrs = ArrayAttr.get(list(attrs))
  return attrs
30
31
def _get_str_array_attr(attrs: Union[ArrayAttr, Sequence[str]]) -> ArrayAttr:
  """Returns `attrs` as an ArrayAttr of StringAttrs.

  An existing ArrayAttr passes through untouched; otherwise each Python
  string is wrapped in a StringAttr and the results are collected into a new
  ArrayAttr.
  """
  if not isinstance(attrs, ArrayAttr):
    string_attrs = list(map(StringAttr.get, attrs))
    attrs = ArrayAttr.get(string_attrs)
  return attrs
38
39
def _get_str_attr(name: Union[StringAttr, str]) -> Optional[StringAttr]:
  """Returns `name` as a StringAttr, wrapping plain Python strings."""
  return StringAttr.get(name) if isinstance(name, str) else name
46
47
def _get_type_attr(type: Union[TypeAttr, Type]) -> TypeAttr:
  """Returns `type` as a TypeAttr, wrapping bare Type objects."""
  return TypeAttr.get(type) if isinstance(type, Type) else type
54
55
class ApplyNativeConstraintOp:
  """Specialization for PDL apply native constraint op class."""

  def __init__(self,
               name: Union[str, StringAttr],
               args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL apply native constraint operation.

    Args:
      name: Name of the native constraint, as a Python string or StringAttr.
      args: Operands for the constraint; each may be an OpView, Operation or
        Value. Defaults to no operands.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    # A `None` sentinel replaces the former shared `[]` default so no list
    # object is ever shared between calls.
    if args is None:
      args = []
    name = _get_str_attr(name)
    args = _get_values(args)
    super().__init__(name, args, loc=loc, ip=ip)
68
69
class ApplyNativeRewriteOp:
  """Specialization for PDL apply native rewrite op class."""

  def __init__(self,
               results: Sequence[Type],
               name: Union[str, StringAttr],
               args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL apply native rewrite operation.

    Args:
      results: Result types of the operation.
      name: Name of the native rewrite, as a Python string or StringAttr.
      args: Operands for the rewrite; each may be an OpView, Operation or
        Value. Defaults to no operands.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    # A `None` sentinel replaces the former shared `[]` default so no list
    # object is ever shared between calls.
    if args is None:
      args = []
    name = _get_str_attr(name)
    args = _get_values(args)
    super().__init__(results, name, args, loc=loc, ip=ip)
83
84
class AttributeOp:
  """Specialization for PDL attribute op class."""

  def __init__(self,
               type: Optional[Union[OpView, Operation, Value]] = None,
               value: Optional[Attribute] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL attribute op whose result has `pdl.AttributeType`.

    Args:
      type: Optional type handle; OpView/Operation inputs are converted to a
        Value before being passed on as the op's `type` operand.
      value: Optional attribute forwarded as the op's `value`.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    type_value = None if type is None else _get_value(type)
    super().__init__(pdl.AttributeType.get(),
                     type=type_value,
                     value=value,
                     loc=loc,
                     ip=ip)
97
98
class EraseOp:
  """Specialization for PDL erase op class."""

  def __init__(self,
               operation: Optional[Union[OpView, Operation, Value]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL erase operation.

    Args:
      operation: Handle of the operation to erase (OpView, Operation or
        Value). The `None` default is kept for signature compatibility, but
        a value must be supplied.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.

    Raises:
      ValueError: If `operation` is None.
    """
    # Unlike the optional operands elsewhere in this file, `operation` is
    # passed to `_get_value` unconditionally, so reject None up front with a
    # clear message instead of failing inside the conversion helper.
    if operation is None:
      raise ValueError("EraseOp requires an `operation` to erase")
    operation = _get_value(operation)
    super().__init__(operation, loc=loc, ip=ip)
109
110
class OperandOp:
  """Specialization for PDL operand op class."""

  def __init__(self,
               type: Optional[Union[OpView, Operation, Value]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL operand op whose result has `pdl.ValueType`.

    Args:
      type: Optional type handle; converted to a Value when given and
        forwarded as the op's `type` operand.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    type_value = type
    if type_value is not None:
      type_value = _get_value(type_value)
    super().__init__(pdl.ValueType.get(), type=type_value, loc=loc, ip=ip)
122
123
class OperandsOp:
  """Specialization for PDL operands op class."""

  def __init__(self,
               types: Optional[Union[OpView, Operation, Value]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL operands op whose result is a range of `pdl.ValueType`.

    Args:
      types: Optional handle to a range of types; converted to a Value when
        given and forwarded as the op's `type` operand.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    types_value = types
    if types_value is not None:
      types_value = _get_value(types_value)
    value_range_type = pdl.RangeType.get(pdl.ValueType.get())
    super().__init__(value_range_type, type=types_value, loc=loc, ip=ip)
135
136
class OperationOp:
  """Specialization for PDL operation op class."""

  def __init__(self,
               name: Optional[Union[str, StringAttr]] = None,
               args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
               attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
               types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL operation op whose result has `pdl.OperationType`.

    Args:
      name: Optional operation name, as a Python string or StringAttr.
      args: Operand handles (OpView, Operation or Value). Defaults to none.
      attributes: Mapping from attribute name to attribute handle. The names
        become an ArrayAttr of StringAttrs and the handles become operands,
        kept in the mapping's iteration order. Defaults to none.
      types: Result-type handles (OpView, Operation or Value). Defaults to
        none.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    # `None` sentinels replace the former shared `[]`/`{}` defaults so no
    # container object is ever shared between calls.
    if args is None:
      args = []
    if attributes is None:
      attributes = {}
    if types is None:
      types = []
    name = name if name is None else _get_str_attr(name)
    args = _get_values(args)
    attribute_names = []
    attribute_values = []
    for attr_name, attr_value in attributes.items():
      attribute_names.append(StringAttr.get(attr_name))
      attribute_values.append(_get_value(attr_value))
    attribute_names_attr = ArrayAttr.get(attribute_names)
    types = _get_values(types)
    result = pdl.OperationType.get()
    super().__init__(result,
                     args,
                     attribute_values,
                     attribute_names_attr,
                     types,
                     name=name,
                     loc=loc,
                     ip=ip)
159
160
class PatternOp:
  """Specialization for PDL pattern op class."""

  def __init__(self,
               benefit: Union[IntegerAttr, int],
               name: Optional[Union[StringAttr, str]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL `pattern` operation and appends a block to its region.

    Args:
      benefit: Pattern benefit; plain ints become 16-bit signless
        IntegerAttrs.
      name: Optional symbol name, as a Python string or StringAttr.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    sym_name = _get_str_attr(name) if name is not None else None
    super().__init__(_get_int_attr(16, benefit),
                     sym_name=sym_name,
                     loc=loc,
                     ip=ip)
    # Append the block that `body` hands back to callers.
    self.regions[0].blocks.append()

  @property
  def body(self):
    """The first block of the pattern's first region."""
    pattern_region = self.regions[0]
    return pattern_region.blocks[0]
180
181
class ReplaceOp:
  """Specialization for PDL replace op class."""

  def __init__(self,
               op: Union[OpView, Operation, Value],
               *,
               with_op: Optional[Union[OpView, Operation, Value]] = None,
               with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
               loc=None,
               ip=None):
    """Creates a PDL replace operation.

    Args:
      op: Handle of the operation being replaced.
      with_op: Optional handle of a replacement operation, forwarded as
        `replOperation`.
      with_values: Replacement value handles. Defaults to none.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    # A `None` sentinel replaces the former shared `[]` default so no list
    # object is ever shared between calls.
    if with_values is None:
      with_values = []
    op = _get_value(op)
    with_op = with_op if with_op is None else _get_value(with_op)
    with_values = _get_values(with_values)
    super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
196
197
class ResultOp:
  """Specialization for PDL result op class."""

  def __init__(self,
               parent: Union[OpView, Operation, Value],
               index: Union[IntegerAttr, int],
               *,
               loc=None,
               ip=None):
    """Creates a PDL result op whose result has `pdl.ValueType`.

    Args:
      parent: Handle of the operation whose result is selected.
      index: Result index; plain ints become 32-bit signless IntegerAttrs.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    index_attr = _get_int_attr(32, index)
    parent_value = _get_value(parent)
    super().__init__(pdl.ValueType.get(),
                     parent_value,
                     index_attr,
                     loc=loc,
                     ip=ip)
211
212
class ResultsOp:
  """Specialization for PDL results op class."""

  def __init__(self,
               result: Type,
               parent: Union[OpView, Operation, Value],
               index: Optional[Union[IntegerAttr, int]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL results op.

    Args:
      result: Result type of the op.
      parent: Handle of the operation whose results are selected.
      index: Optional index; plain ints become 32-bit signless IntegerAttrs.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    parent_value = _get_value(parent)
    index_attr = None
    if index is not None:
      index_attr = _get_int_attr(32, index)
    super().__init__(result, parent_value, index=index_attr, loc=loc, ip=ip)
226
227
class RewriteOp:
  """Specialization for PDL rewrite op class."""

  def __init__(self,
               root: Optional[Union[OpView, Operation, Value]] = None,
               name: Optional[Union[StringAttr, str]] = None,
               args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL rewrite operation.

    Args:
      root: Optional handle of the root operation, forwarded as `root`.
      name: Optional name, as a Python string or StringAttr.
      args: Operand handles (OpView, Operation or Value). Defaults to none.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    # A `None` sentinel replaces the former shared `[]` default so no list
    # object is ever shared between calls.
    if args is None:
      args = []
    root = root if root is None else _get_value(root)
    name = name if name is None else _get_str_attr(name)
    args = _get_values(args)
    super().__init__(args, root=root, name=name, loc=loc, ip=ip)

  def add_body(self):
    """Append a block to the rewrite's first region and return `body`.

    Note that `body` is always the region's first block, so calling this
    more than once appends extra blocks but still returns the first one.
    """
    self.regions[0].blocks.append()
    return self.body

  @property
  def body(self):
    """Return the body (first block of the first region) of the rewrite."""
    return self.regions[0].blocks[0]
252
253
class TypeOp:
  """Specialization for PDL type op class."""

  def __init__(self,
               type: Optional[Union[TypeAttr, Type]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL type op whose result has `pdl.TypeType`.

    Args:
      type: Optional type; a bare Type is wrapped in a TypeAttr before being
        forwarded as the op's `type` attribute.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    type_attr = None
    if type is not None:
      type_attr = _get_type_attr(type)
    super().__init__(pdl.TypeType.get(), type=type_attr, loc=loc, ip=ip)
265
266
class TypesOp:
  """Specialization for PDL types op class."""

  def __init__(self,
               types: Optional[Sequence[Union[TypeAttr, Type]]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a PDL types op whose result is a range of `pdl.TypeType`.

    Args:
      types: Types to attach; bare Types are wrapped in TypeAttrs. When empty
        or omitted, no `types` attribute is attached.
      loc: Optional location forwarded to the generated builder.
      ip: Optional insertion point forwarded to the generated builder.
    """
    # A `None` sentinel replaces the former shared `[]` default so no list
    # object is ever shared between calls.
    if types is None:
      types = []
    type_attrs = [_get_type_attr(ty) for ty in types]
    # Decide on the plain Python list, so emptiness is checked before
    # building the attribute rather than via the ArrayAttr's truthiness.
    types_attr = _get_array_attr(type_attrs) if type_attrs else None
    result = pdl.RangeType.get(pdl.TypeType.get())
    super().__init__(result, types=types_attr, loc=loc, ip=ip)
279