# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
  from ..ir import *
  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
  from ..dialects import pdl
except ImportError as e:
  raise RuntimeError("Error loading imports from extension module") from e

from typing import List, Optional, Sequence, Union

# A sequence whose elements are integer attributes or plain Python ints,
# possibly mixed.
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
# Either a ready-made ArrayAttr, an IntOrAttrList, or None.
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]


def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr:
  """Converts a Python int into a signless 64-bit integer attribute.

  Any value that is not an int is returned unchanged; it is expected (but not
  checked) to already be an integer attribute.
  """
  if isinstance(value, int):
    return IntegerAttr.get(IntegerType.get_signless(64), value)
  return value


def _get_array_attr(
    values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
  """Creates an array attribute from its operand.

  `None` produces an empty array attribute, an existing ArrayAttr is forwarded
  as-is, and any other sequence of attributes is wrapped in a new ArrayAttr.
  """
  if values is None:
    return ArrayAttr.get([])
  if isinstance(values, ArrayAttr):
    return values

  return ArrayAttr.get(values)


def _get_int_array_attr(
    values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]]
) -> ArrayAttr:
  """Creates an integer array attribute from its operand.

  If the operand is already an array attribute, forwards it. `None` produces
  an empty array attribute. Otherwise treats the operand as a list of integer
  attributes or Python ints, possibly interspersed, and creates a new array
  attribute containing integer attributes (ints become signless i64). Expects
  the thread-local MLIR context to have been set by the context manager.
  """
  if values is None:
    return ArrayAttr.get([])
  if isinstance(values, ArrayAttr):
    return values

  return ArrayAttr.get([_get_int64_attr(v) for v in values])


def _get_int_int_array_attr(
    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
                                                     IntOrAttrList]]]]
) -> ArrayAttr:
  """Creates an array attribute containing array attributes of integers.

  If the operand is already an array attribute, forwards it. `None` produces
  an empty array attribute. Otherwise treats the operand as a sequence whose
  elements are either array attributes or lists of integer attributes/Python
  ints, possibly interspersed. Each element is converted with
  `_get_int_array_attr` to build a new array-of-array attribute. Expects the
  thread-local MLIR context to have been set by the context manager.
  """
  if values is None:
    return ArrayAttr.get([])
  if isinstance(values, ArrayAttr):
    return values

  return ArrayAttr.get([_get_int_array_attr(value) for value in values])


# NOTE(review): the classes below have no explicit base. They are presumably
# mixed into the ODS-generated op classes of the same name, so that
# `super().__init__` resolves to the generated builder -- confirm against the
# extension-loading mechanism.


class DecomposeOp:
  """Specialization for DecomposeOp class."""

  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
    """Builds the op on `target` with a single `pdl.OperationType` result.

    Args:
      target: the payload handle; an Operation is converted to its result
        value via `_get_op_result_or_value`.
      loc, ip: optional location and insertion point, forwarded to the
        builder.
    """
    super().__init__(
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        loc=loc,
        ip=ip)


class GeneralizeOp:
  """Specialization for GeneralizeOp class."""

  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
    """Builds the op on `target` with a single `pdl.OperationType` result.

    Args:
      target: the payload handle; an Operation is converted to its result
        value via `_get_op_result_or_value`.
      loc, ip: optional location and insertion point, forwarded to the
        builder.
    """
    super().__init__(
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        loc=loc,
        ip=ip)


class InterchangeOp:
  """Specialization for InterchangeOp class."""

  def __init__(self,
               target: Union[Operation, Value],
               *,
               iterator_interchange: OptionalIntList = None,
               loc=None,
               ip=None):
    """Builds the op on `target` with a single `pdl.OperationType` result.

    Args:
      target: the payload handle (Operation or Value).
      iterator_interchange: ArrayAttr or list of ints/integer attributes.
        None becomes an empty array attribute.
      loc, ip: optional location and insertion point.
    """
    pdl_operation_type = pdl.OperationType.get()
    interchange_attr = _get_int_array_attr(iterator_interchange)
    super().__init__(
        pdl_operation_type,
        _get_op_result_or_value(target),
        iterator_interchange=interchange_attr,
        loc=loc,
        ip=ip)


class MultiTileSizesOp:
  """Specialization for MultiTileSizesOp class."""

  def __init__(self,
               target: Union[Operation, Value],
               *,
               dimension: Union[int, IntegerAttr],
               target_size: Union[int, IntegerAttr],
               divisor: Optional[Union[int, IntegerAttr]] = None,
               loc=None,
               ip=None):
    """Builds the op on `target` with three `pdl.OperationType` results.

    Args:
      target: the payload handle (Operation or Value).
      dimension, target_size: ints (converted to i64 attributes) or integer
        attributes.
      divisor: int or integer attribute. It defaults to 1 only when omitted
        (None).
      loc, ip: optional location and insertion point.
    """
    super().__init__(
        pdl.OperationType.get(),
        pdl.OperationType.get(),
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        dimension=_get_int64_attr(dimension),
        target_size=_get_int64_attr(target_size),
        # Only substitute the default when the argument is omitted. A
        # truthiness test would silently rewrite an explicit int 0 to 1.
        divisor=_get_int64_attr(divisor if divisor is not None else 1),
        loc=loc,
        ip=ip)


class PadOp:
  """Specialization for PadOp class."""

  def __init__(self,
               target: Union[Operation, Value],
               *,
               padding_values: Optional[Union[ArrayAttr,
                                              Sequence[Attribute]]] = None,
               padding_dimensions: OptionalIntList = None,
               pack_paddings: OptionalIntList = None,
               hoist_paddings: OptionalIntList = None,
               transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
                   ArrayAttr, IntOrAttrList]]]] = None,
               loc=None,
               ip=None):
    """Builds the op on `target` with a single `pdl.OperationType` result.

    Every optional list argument left as None becomes an empty array
    attribute.

    Args:
      target: the payload handle (Operation or Value).
      padding_values: ArrayAttr or sequence of attributes.
      padding_dimensions, pack_paddings, hoist_paddings: ArrayAttr or lists of
        ints/integer attributes.
      transpose_paddings: ArrayAttr or a sequence of ArrayAttrs/integer lists.
      loc, ip: optional location and insertion point.
    """
    pdl_operation_type = pdl.OperationType.get()
    padding_values_attr = _get_array_attr(padding_values)
    padding_dimensions_attr = _get_int_array_attr(padding_dimensions)
    pack_paddings_attr = _get_int_array_attr(pack_paddings)
    hoist_paddings_attr = _get_int_array_attr(hoist_paddings)
    transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
    super().__init__(
        pdl_operation_type,
        _get_op_result_or_value(target),
        padding_values=padding_values_attr,
        padding_dimensions=padding_dimensions_attr,
        pack_paddings=pack_paddings_attr,
        hoist_paddings=hoist_paddings_attr,
        transpose_paddings=transpose_paddings_attr,
        loc=loc,
        ip=ip)


class ScalarizeOp:
  """Specialization for ScalarizeOp class."""

  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
    """Builds the op on `target` with a single `pdl.OperationType` result."""
    pdl_operation_type = pdl.OperationType.get()
    super().__init__(
        pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)


class SplitOp:
  """Specialization for SplitOp class."""

  def __init__(self,
               target: Union[Operation, Value],
               dimension: Union[int, Attribute],
               split_point: Union[int, Operation, Value, Attribute],
               *,
               loc=None,
               ip=None):
    """Builds the op on `target` with two `pdl.OperationType` results.

    Args:
      target: the payload handle (Operation or Value).
      dimension: int (converted to an i64 attribute) or integer attribute.
      split_point: either a static value (int or Attribute), or a dynamic
        value (Operation or Value). In the dynamic case the static split
        point is set to `ShapedType._get_dynamic_size()`, and the value is
        passed as `dynamic_split_point`.
      loc, ip: optional location and insertion point.
    """
    dimension = _get_int64_attr(dimension)
    if isinstance(split_point, int):
      split_point = _get_int64_attr(split_point)

    if isinstance(split_point, Attribute):
      static_split_point = split_point
      dynamic_split_point = None
    else:
      static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
      dynamic_split_point = _get_op_result_or_value(split_point)

    pdl_operation_type = pdl.OperationType.get()
    super().__init__(
        pdl_operation_type,
        pdl_operation_type,
        _get_op_result_or_value(target),
        dimension=dimension,
        static_split_point=static_split_point,
        dynamic_split_point=dynamic_split_point,
        loc=loc,
        ip=ip)


class TileOp:
  """Specialization for TileOp class."""

  def __init__(self,
               target: Union[Operation, Value],
               *,
               sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
                                                    Value]], ArrayAttr]] = None,
               interchange: OptionalIntList = None,
               loc=None,
               ip=None):
    """Builds the op on `target`.

    The op has one `pdl.OperationType` result for the tiled op, plus one
    `pdl.OperationType` result per non-zero static tile size.

    Args:
      target: the payload handle (Operation or Value).
      sizes: ArrayAttr of static sizes, or a sequence mixing ints, integer
        attributes, and dynamic values (Operation/Value). Each dynamic entry
        is recorded in the static sizes as `ShapedType._get_dynamic_size()`
        and forwarded in `dynamic_sizes`. None means no sizes.
      interchange: optional ArrayAttr or list of ints/integer attributes.
        An empty or missing value passes None.
      loc, ip: optional location and insertion point.
    """
    pdl_operation_type = pdl.OperationType.get()
    i64_type = IntegerType.get_signless(64)

    if sizes is None:
      sizes = []

    static_sizes = []
    dynamic_sizes = []
    if isinstance(sizes, ArrayAttr):
      sizes_attr = sizes
    else:
      for size in sizes:
        if isinstance(size, int):
          static_sizes.append(IntegerAttr.get(i64_type, size))
        elif isinstance(size, IntegerAttr):
          static_sizes.append(size)
        else:
          static_sizes.append(
              IntegerAttr.get(i64_type, ShapedType._get_dynamic_size()))
          dynamic_sizes.append(_get_op_result_or_value(size))
      sizes_attr = ArrayAttr.get(static_sizes)

    # One loop result per non-zero static size. Dynamic entries hold
    # `ShapedType._get_dynamic_size()`, presumably non-zero, so they count.
    num_loops = sum(1 for v in self.__extract_values(sizes_attr) if v != 0)
    super().__init__(
        pdl_operation_type, [pdl_operation_type] * num_loops,
        _get_op_result_or_value(target),
        dynamic_sizes=dynamic_sizes,
        static_sizes=sizes_attr,
        interchange=_get_int_array_attr(interchange) if interchange else None,
        loc=loc,
        ip=ip)

  def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
    """Returns the integer values of an array of integer attributes.

    A missing or empty attribute yields [].
    """
    if not attr:
      return []
    return [IntegerAttr(element).value for element in attr]


class VectorizeOp:
  """Specialization for VectorizeOp class."""

  def __init__(self,
               target: Union[Operation, Value],
               *,
               vectorize_padding: Union[bool, BoolAttr] = False,
               loc=None,
               ip=None):
    """Builds the op on `target` with a single `pdl.OperationType` result.

    Args:
      target: the payload handle (Operation or Value).
      vectorize_padding: Python bool (wrapped in a BoolAttr) or BoolAttr.
      loc, ip: optional location and insertion point.
    """
    pdl_operation_type = pdl.OperationType.get()
    if isinstance(vectorize_padding, bool):
      vectorize_padding = BoolAttr.get(vectorize_padding)
    super().__init__(
        pdl_operation_type,
        _get_op_result_or_value(target),
        vectorize_padding=vectorize_padding,
        loc=loc,
        ip=ip)