1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5try:
6  from ..ir import *
7  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
8  from ..dialects import pdl
9except ImportError as e:
10  raise RuntimeError("Error loading imports from extension module") from e
11
12from typing import List, Optional, Sequence, Union
13
# A sequence of integers in which each entry may be given either as an
# IntegerAttr or as a plain Python int.
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
# An optional integer list argument: None, a ready-made ArrayAttr, or an
# IntOrAttrList.
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
16
17
def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr:
  """Wraps a Python int into a signless 64-bit IntegerAttr.

  A `bool` is an `int` subclass and is therefore converted as well. Any other
  argument is assumed to already be an attribute and is returned unchanged,
  without validation.
  """
  if not isinstance(value, int):
    return value
  i64 = IntegerType.get_signless(64)
  return IntegerAttr.get(i64, value)
22
23
def _get_array_attr(
    values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
  """Creates an array attribute from its operand.

  `None` yields an empty array attribute, and an existing `ArrayAttr` is
  forwarded unchanged. Any other sequence of attributes is materialized into
  a Python list before being handed to `ArrayAttr.get`, so tuples and other
  non-list sequences are accepted as the annotation promises. Expects the
  thread-local MLIR context to have been set by the context manager.
  """
  if values is None:
    return ArrayAttr.get([])
  if isinstance(values, ArrayAttr):
    return values

  # ArrayAttr.get is bound to take a Python list; convert explicitly so that
  # any Sequence (e.g. a tuple) works, matching the sibling helpers, which
  # also build a list.
  return ArrayAttr.get(list(values))
33
34
def _get_int_array_attr(
    values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]]
) -> ArrayAttr:
  """Builds an array attribute whose elements are integer attributes.

  An existing array attribute is returned as-is, and `None` becomes an empty
  array attribute. Otherwise the operand is treated as a sequence whose
  elements may be a mix of attributes and plain Python integers, possibly
  interspersed. Integers are wrapped into signless 64-bit integer attributes;
  other elements are passed through unchanged. Expects the thread-local MLIR
  context to have been set by the context manager.
  """
  if isinstance(values, ArrayAttr):
    return values
  elements = [] if values is None else [_get_int64_attr(v) for v in values]
  return ArrayAttr.get(elements)
51
52
def _get_int_int_array_attr(
    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
                                                     IntOrAttrList]]]]
) -> ArrayAttr:
  """Creates an array attribute whose elements are integer array attributes.

  An existing array attribute is forwarded unchanged, and `None` becomes an
  empty array attribute. Otherwise each element of the operand is converted
  with `_get_int_array_attr`. An element may therefore be an ArrayAttr, None,
  or a sequence of attributes and Python integers, possibly interspersed.
  Expects the thread-local MLIR context to have been set by the context
  manager.
  """
  if values is None:
    values = []
  elif isinstance(values, ArrayAttr):
    return values
  nested = [_get_int_array_attr(row) for row in values]
  return ArrayAttr.get(nested)
70
71
class DecomposeOp:
  """Specialization for DecomposeOp class.

  Has no explicit base class; `super().__init__` presumably reaches the
  generated op builder this class is mixed into (confirm with the generated
  bindings).
  """

  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
    """Builds the op on `target`, which may be an Operation or a Value."""
    pdl_op_type = pdl.OperationType.get()
    target_value = _get_op_result_or_value(target)
    super().__init__(pdl_op_type, target_value, loc=loc, ip=ip)
81
82
class GeneralizeOp:
  """Specialization for GeneralizeOp class.

  Has no explicit base class; `super().__init__` presumably reaches the
  generated op builder this class is mixed into (confirm with the generated
  bindings).
  """

  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
    """Builds the op; `target` is normalized to a Value before forwarding."""
    op_type = pdl.OperationType.get()
    operand = _get_op_result_or_value(target)
    super().__init__(op_type, operand, loc=loc, ip=ip)
92
93
class InterchangeOp:
  """Specialization for InterchangeOp class.

  `iterator_interchange` may be an ArrayAttr or a sequence of ints and/or
  IntegerAttrs; None is turned into an empty array attribute.
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               iterator_interchange: OptionalIntList = None,
               loc=None,
               ip=None):
    op_type = pdl.OperationType.get()
    interchange_array = _get_int_array_attr(iterator_interchange)
    target_value = _get_op_result_or_value(target)
    super().__init__(
        op_type,
        target_value,
        iterator_interchange=interchange_array,
        loc=loc,
        ip=ip)
111
112
class MultiTileSizesOp:
  """Specialization for MultiTileSizesOp class.

  Passes three `pdl.OperationType` values ahead of the target operand to the
  generated builder. `dimension`, `target_size` and `divisor` may each be a
  Python int (converted to a signless 64-bit IntegerAttr) or an existing
  attribute. `divisor` defaults to 1 only when it is omitted (None).
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               dimension: Union[int, IntegerAttr],
               target_size: Union[int, IntegerAttr],
               divisor: Optional[Union[int, IntegerAttr]] = None,
               loc=None,
               ip=None):
    # Substitute the default only when the caller omitted the divisor.
    # Testing truthiness would silently turn an explicit 0 into 1. Passing it
    # through instead lets the op (presumably its verifier) reject it.
    if divisor is None:
      divisor = 1
    super().__init__(
        pdl.OperationType.get(),
        pdl.OperationType.get(),
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        dimension=_get_int64_attr(dimension),
        target_size=_get_int64_attr(target_size),
        divisor=_get_int64_attr(divisor),
        loc=loc,
        ip=ip)
134
135
class PadOp:
  """Specialization for PadOp class.

  Each optional list argument may be a ready-made ArrayAttr or a Python
  sequence. Arguments left as None become empty array attributes.
  `transpose_paddings` is a list of integer lists (or ArrayAttrs).
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               padding_values: Optional[Union[ArrayAttr,
                                              Sequence[Attribute]]] = None,
               padding_dimensions: OptionalIntList = None,
               pack_paddings: OptionalIntList = None,
               hoist_paddings: OptionalIntList = None,
               transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
                   ArrayAttr, IntOrAttrList]]]] = None,
               loc=None,
               ip=None):
    op_type = pdl.OperationType.get()
    # Converted in the same order as the keyword arguments are passed below.
    attrs = dict(
        padding_values=_get_array_attr(padding_values),
        padding_dimensions=_get_int_array_attr(padding_dimensions),
        pack_paddings=_get_int_array_attr(pack_paddings),
        hoist_paddings=_get_int_array_attr(hoist_paddings),
        transpose_paddings=_get_int_int_array_attr(transpose_paddings),
    )
    super().__init__(
        op_type, _get_op_result_or_value(target), **attrs, loc=loc, ip=ip)
167
168
class ScalarizeOp:
  """Specialization for ScalarizeOp class.

  Has no explicit base class; `super().__init__` presumably reaches the
  generated op builder this class is mixed into (confirm with the generated
  bindings).
  """

  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
    """Builds the op on `target`, which may be an Operation or a Value."""
    result_kind = pdl.OperationType.get()
    target_value = _get_op_result_or_value(target)
    super().__init__(result_kind, target_value, loc=loc, ip=ip)
176
177
class SplitOp:
  """Specialization for SplitOp class.

  The split point is either static (a Python int or an attribute) or dynamic
  (an Operation or Value). A dynamic split point is passed as an operand, and
  the static attribute is set to `ShapedType._get_dynamic_size()` to mark it
  as dynamic.
  """

  def __init__(self,
               target: Union[Operation, Value],
               dimension: Union[int, Attribute],
               split_point: Union[int, Operation, Value, Attribute],
               *,
               loc=None,
               ip=None):
    dimension_attr = _get_int64_attr(dimension)

    if isinstance(split_point, (int, Attribute)):
      # Ints are wrapped into an i64 attribute; attributes pass through as-is.
      static_split_point = _get_int64_attr(split_point)
      dynamic_split_point = None
    else:
      static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
      dynamic_split_point = _get_op_result_or_value(split_point)

    op_type = pdl.OperationType.get()
    super().__init__(
        op_type,
        op_type,
        _get_op_result_or_value(target),
        dimension=dimension_attr,
        static_split_point=static_split_point,
        dynamic_split_point=dynamic_split_point,
        loc=loc,
        ip=ip)
209
210
class TileOp:
  """Specialization for TileOp class.

  `sizes` may be an ArrayAttr of static sizes, or a sequence mixing Python
  ints, IntegerAttrs and Operations/Values. Each Operation/Value is forwarded
  as a dynamic size operand, and its slot in the static sizes attribute holds
  `ShapedType._get_dynamic_size()`. The builder receives one
  `pdl.OperationType`, then a list with one `pdl.OperationType` per non-zero
  static size (`num_loops`). `interchange` is forwarded only when truthy;
  otherwise None is passed.
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
                                                    Value]], ArrayAttr]] = None,
               interchange: OptionalIntList = None,
               loc=None,
               ip=None):
    op_type = pdl.OperationType.get()
    i64 = IntegerType.get_signless(64)

    dynamic_sizes = []
    if isinstance(sizes, ArrayAttr):
      sizes_attr = sizes
    else:
      static_sizes = []
      for entry in ([] if sizes is None else sizes):
        if isinstance(entry, int):
          static_sizes.append(IntegerAttr.get(i64, entry))
        elif isinstance(entry, IntegerAttr):
          static_sizes.append(entry)
        else:
          # Non-constant size: mark the static slot as dynamic and pass the
          # actual value as an operand.
          static_sizes.append(
              IntegerAttr.get(i64, ShapedType._get_dynamic_size()))
          dynamic_sizes.append(_get_op_result_or_value(entry))
      sizes_attr = ArrayAttr.get(static_sizes)

    # Zero sizes do not contribute a loop result; every other size does.
    num_loops = sum(
        1 for size in self.__extract_values(sizes_attr) if size != 0)
    target_value = _get_op_result_or_value(target)
    interchange_attr = (
        _get_int_array_attr(interchange) if interchange else None)
    super().__init__(
        op_type, [op_type] * num_loops,
        target_value,
        dynamic_sizes=dynamic_sizes,
        static_sizes=sizes_attr,
        interchange=interchange_attr,
        loc=loc,
        ip=ip)

  def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
    """Returns the Python integer value of every element of `attr`.

    A falsy `attr` (e.g. None) yields an empty list. Each element is cast with
    `IntegerAttr(...)` before its `.value` is read.
    """
    return [IntegerAttr(element).value for element in attr] if attr else []
259
260
class VectorizeOp:
  """Specialization for VectorizeOp class.

  `vectorize_padding` may be a Python bool, which is wrapped in a BoolAttr, or
  an existing BoolAttr, which is forwarded unchanged.
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               vectorize_padding: Union[bool, BoolAttr] = False,
               loc=None,
               ip=None):
    op_type = pdl.OperationType.get()
    padding_attr = (
        BoolAttr.get(vectorize_padding)
        if isinstance(vectorize_padding, bool) else vectorize_padding)
    super().__init__(
        op_type,
        _get_op_result_or_value(target),
        vectorize_padding=padding_attr,
        loc=loc,
        ip=ip)
279