# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from ..ir import *
    from ..dialects import pdl
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Union, Optional, Sequence, List, Mapping
from ._ods_common import get_op_result_or_value as _get_value, get_op_results_or_values as _get_values

# NOTE(review): none of the classes below declare a base class, yet all of
# them call `super().__init__(...)`. Presumably they are mixed into the
# generated PDL op classes elsewhere -- confirm against the loading mechanism.


def _get_int_attr(bits: int, value: Union[IntegerAttr, int]) -> IntegerAttr:
    """Converts the given value to signless integer attribute of given bit width.

    A Python `int` is wrapped in an `IntegerAttr` of a signless integer type
    that is `bits` wide. Any other value is returned unchanged.
    """
    if isinstance(value, int):
        ty = IntegerType.get_signless(bits)
        return IntegerAttr.get(ty, value)
    return value


def _get_array_attr(attrs: Union[ArrayAttr, Sequence[Attribute]]) -> ArrayAttr:
    """Converts the given value to array attribute.

    An existing `ArrayAttr` is returned as is; anything else is materialized
    into a list and wrapped in a new `ArrayAttr`.
    """
    if isinstance(attrs, ArrayAttr):
        return attrs
    return ArrayAttr.get(list(attrs))


def _get_str_array_attr(attrs: Union[ArrayAttr, Sequence[str]]) -> ArrayAttr:
    """Converts the given value to string array attribute.

    An existing `ArrayAttr` is returned as is; otherwise every element is
    turned into a `StringAttr` and the results are wrapped in an `ArrayAttr`.
    """
    if isinstance(attrs, ArrayAttr):
        return attrs
    return ArrayAttr.get([StringAttr.get(s) for s in attrs])


def _get_str_attr(name: Union[StringAttr, str]) -> Optional[StringAttr]:
    """Converts the given value to string attribute.

    A Python `str` is wrapped in a `StringAttr`. Any other value (including
    `None`) is returned unchanged.
    """
    if isinstance(name, str):
        return StringAttr.get(name)
    return name


def _get_type_attr(type: Union[TypeAttr, Type]) -> TypeAttr:
    """Converts the given value to type attribute.

    A `Type` is wrapped in a `TypeAttr`. Any other value is returned unchanged.
    """
    if isinstance(type, Type):
        return TypeAttr.get(type)
    return type


class ApplyNativeConstraintOp:
    """Specialization for PDL apply native constraint op class."""

    def __init__(self,
                 name: Union[str, StringAttr],
                 args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op from a constraint name and its arguments.

        Args:
          name: Constraint name, as a string or a ready-made `StringAttr`.
          args: Arguments of the constraint; `None` means no arguments.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        # A fresh list per call avoids sharing one mutable default object.
        if args is None:
            args = []
        name = _get_str_attr(name)
        args = _get_values(args)
        super().__init__(name, args, loc=loc, ip=ip)


class ApplyNativeRewriteOp:
    """Specialization for PDL apply native rewrite op class."""

    def __init__(self,
                 results: Sequence[Type],
                 name: Union[str, StringAttr],
                 args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op from result types, a rewrite name and its arguments.

        Args:
          results: Result types, forwarded unchanged to the parent constructor.
          name: Rewrite name, as a string or a ready-made `StringAttr`.
          args: Arguments of the rewrite; `None` means no arguments.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        # A fresh list per call avoids sharing one mutable default object.
        if args is None:
            args = []
        name = _get_str_attr(name)
        args = _get_values(args)
        super().__init__(results, name, args, loc=loc, ip=ip)


class AttributeOp:
    """Specialization for PDL attribute op class."""

    def __init__(self,
                 type: Optional[Union[OpView, Operation, Value]] = None,
                 value: Optional[Attribute] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op with a `pdl.AttributeType` result.

        Args:
          type: Optional type operand; resolved through `_get_value` when given.
          value: Optional attribute, forwarded unchanged.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        type = type if type is None else _get_value(type)
        result = pdl.AttributeType.get()
        super().__init__(result, type=type, value=value, loc=loc, ip=ip)


class EraseOp:
    """Specialization for PDL erase op class."""

    def __init__(self,
                 operation: Optional[Union[OpView, Operation, Value]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op for the given operation operand.

        Args:
          operation: Operand to erase; resolved through `_get_value`.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        # NOTE(review): unlike the other optional operands in this file, a
        # `None` default is passed to `_get_value` without a guard -- confirm
        # whether `operation` is meant to be required.
        operation = _get_value(operation)
        super().__init__(operation, loc=loc, ip=ip)


class OperandOp:
    """Specialization for PDL operand op class."""

    def __init__(self,
                 type: Optional[Union[OpView, Operation, Value]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op with a `pdl.ValueType` result.

        Args:
          type: Optional type operand; resolved through `_get_value` when given.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        type = type if type is None else _get_value(type)
        result = pdl.ValueType.get()
        super().__init__(result, type=type, loc=loc, ip=ip)


class OperandsOp:
    """Specialization for PDL operands op class."""

    def __init__(self,
                 types: Optional[Union[OpView, Operation, Value]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op with a range-of-`pdl.ValueType` result.

        Args:
          types: Optional types operand; resolved through `_get_value` when
            given and passed to the parent constructor as `type`.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        types = types if types is None else _get_value(types)
        result = pdl.RangeType.get(pdl.ValueType.get())
        super().__init__(result, type=types, loc=loc, ip=ip)


class OperationOp:
    """Specialization for PDL operation op class."""

    def __init__(self,
                 name: Optional[Union[str, StringAttr]] = None,
                 args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
                 attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
                 types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op with a `pdl.OperationType` result.

        Args:
          name: Optional operation name, as a string or a `StringAttr`.
          args: Operand values; `None` means no operands.
          attributes: Mapping from attribute name to the value defining that
            attribute; `None` means no attributes.
          types: Result type values; `None` means no result types.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        # Fresh containers per call avoid sharing mutable default objects.
        if args is None:
            args = []
        if attributes is None:
            attributes = {}
        if types is None:
            types = []
        name = name if name is None else _get_str_attr(name)
        args = _get_values(args)
        # Split the mapping into two parallel sequences: names become string
        # attributes, values are resolved to IR values in the same order.
        attributeNames = []
        attributeValues = []
        for attrName, attrValue in attributes.items():
            attributeNames.append(StringAttr.get(attrName))
            attributeValues.append(_get_value(attrValue))
        attributeNames = ArrayAttr.get(attributeNames)
        types = _get_values(types)
        result = pdl.OperationType.get()
        super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip)


class PatternOp:
    """Specialization for PDL pattern op class."""

    def __init__(self,
                 benefit: Union[IntegerAttr, int],
                 name: Optional[Union[StringAttr, str]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates a PDL `pattern` operation.

        Args:
          benefit: Pattern benefit; a Python `int` is converted to a 16-bit
            signless integer attribute.
          name: Optional symbol name, as a string or a `StringAttr`.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        name_attr = None if name is None else _get_str_attr(name)
        benefit_attr = _get_int_attr(16, benefit)
        super().__init__(benefit_attr, sym_name=name_attr, loc=loc, ip=ip)
        # Append a block to the first region so that `body` is usable
        # right after construction.
        self.regions[0].blocks.append()

    @property
    def body(self):
        """Return the body (block) of the pattern."""
        return self.regions[0].blocks[0]


class ReplaceOp:
    """Specialization for PDL replace op class."""

    def __init__(self,
                 op: Union[OpView, Operation, Value],
                 *,
                 with_op: Optional[Union[OpView, Operation, Value]] = None,
                 with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
                 loc=None,
                 ip=None):
        """Creates the op replacing `op` with another op or with values.

        Args:
          op: Operation to replace; resolved through `_get_value`.
          with_op: Optional replacement operation; resolved through
            `_get_value` when given.
          with_values: Replacement values; `None` means no values.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        # A fresh list per call avoids sharing one mutable default object.
        if with_values is None:
            with_values = []
        op = _get_value(op)
        with_op = with_op if with_op is None else _get_value(with_op)
        with_values = _get_values(with_values)
        super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)


class ResultOp:
    """Specialization for PDL result op class."""

    def __init__(self,
                 parent: Union[OpView, Operation, Value],
                 index: Union[IntegerAttr, int],
                 *,
                 loc=None,
                 ip=None):
        """Creates the op with a `pdl.ValueType` result.

        Args:
          parent: Parent operation value; resolved through `_get_value`.
          index: Result index; a Python `int` is converted to a 32-bit
            signless integer attribute.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        index = _get_int_attr(32, index)
        parent = _get_value(parent)
        result = pdl.ValueType.get()
        super().__init__(result, parent, index, loc=loc, ip=ip)


class ResultsOp:
    """Specialization for PDL results op class."""

    def __init__(self,
                 result: Type,
                 parent: Union[OpView, Operation, Value],
                 index: Optional[Union[IntegerAttr, int]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op with an explicitly provided result type.

        Args:
          result: Result type, forwarded unchanged to the parent constructor.
          parent: Parent operation value; resolved through `_get_value`.
          index: Optional index; a Python `int` is converted to a 32-bit
            signless integer attribute, `None` is forwarded as is.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        parent = _get_value(parent)
        index = index if index is None else _get_int_attr(32, index)
        super().__init__(result, parent, index=index, loc=loc, ip=ip)


class RewriteOp:
    """Specialization for PDL rewrite op class."""

    def __init__(self,
                 root: Optional[Union[OpView, Operation, Value]] = None,
                 name: Optional[Union[StringAttr, str]] = None,
                 args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op; no body block is added by the constructor.

        Args:
          root: Optional root operation; resolved through `_get_value` when
            given.
          name: Optional rewrite name, as a string or a `StringAttr`.
          args: Arguments of the rewrite; `None` means no arguments.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        # A fresh list per call avoids sharing one mutable default object.
        if args is None:
            args = []
        root = root if root is None else _get_value(root)
        name = name if name is None else _get_str_attr(name)
        args = _get_values(args)
        super().__init__(args, root=root, name=name, loc=loc, ip=ip)

    def add_body(self):
        """Add body (block) to the rewrite."""
        self.regions[0].blocks.append()
        return self.body

    @property
    def body(self):
        """Return the body (block) of the rewrite."""
        return self.regions[0].blocks[0]


class TypeOp:
    """Specialization for PDL type op class."""

    def __init__(self,
                 type: Optional[Union[TypeAttr, Type]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op with a `pdl.TypeType` result.

        Args:
          type: Optional constant type; a `Type` is wrapped in a `TypeAttr`.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        type = type if type is None else _get_type_attr(type)
        result = pdl.TypeType.get()
        super().__init__(result, type=type, loc=loc, ip=ip)


class TypesOp:
    """Specialization for PDL types op class."""

    def __init__(self,
                 types: Optional[Sequence[Union[TypeAttr, Type]]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates the op with a range-of-`pdl.TypeType` result.

        Args:
          types: Constant types; each `Type` is wrapped in a `TypeAttr`.
            `None` is treated like an empty sequence.
          loc: Forwarded unchanged to the parent constructor.
          ip: Forwarded unchanged to the parent constructor.
        """
        # A fresh list per call avoids sharing one mutable default object.
        if types is None:
            types = []
        types = _get_array_attr([_get_type_attr(ty) for ty in types])
        # Presumably an empty ArrayAttr is falsy, so that no attribute is
        # passed at all for an empty sequence -- TODO confirm.
        types = None if not types else types
        result = pdl.RangeType.get(pdl.TypeType.get())
        super().__init__(result, types=types, loc=loc, ip=ip)