# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Hand-written Python builder extensions for transform dialect ops.

Each class below calls ``super().__init__`` with op-specific arguments.
NOTE(review): these classes presumably get combined with the ODS-generated
op classes of the same name, so ``super().__init__`` resolves to the
generated builder -- confirm against the extension-loading mechanism.
"""

try:
    from ..ir import *
    from ._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        get_op_results_or_values as _get_op_results_or_values,
    )
    from ..dialects import pdl
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, overload, Sequence, Union


def _get_symbol_ref_attr(value: Union[Attribute, str]):
    """Return ``value`` as a symbol-reference attribute.

    An ``Attribute`` is returned unchanged; a string is wrapped with
    ``FlatSymbolRefAttr.get``.
    """
    if isinstance(value, Attribute):
        return value
    return FlatSymbolRefAttr.get(value)


class GetClosestIsolatedParentOp:
    """Builder extension for the closest-isolated-parent transform op."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        """Build the op on ``target``.

        Args:
          target: handle operand; an ``Operation`` is passed through
            ``_get_op_result_or_value`` to obtain a value.
          loc: optional location forwarded to the base builder.
          ip: optional insertion point forwarded to the base builder.
        """
        # The single result is typed as a PDL operation handle.
        super().__init__(
            pdl.OperationType.get(),
            _get_op_result_or_value(target),
            loc=loc,
            ip=ip,
        )


class MergeHandlesOp:
    """Builder extension for the merge-handles transform op."""

    def __init__(
        self,
        handles: Sequence[Union[Operation, Value]],
        *,
        deduplicate: bool = False,
        loc=None,
        ip=None,
    ):
        """Build the op merging ``handles``.

        Args:
          handles: handles to merge; each is converted with
            ``_get_op_result_or_value``.
          deduplicate: forwarded unchanged to the base builder.
          loc: optional location forwarded to the base builder.
          ip: optional insertion point forwarded to the base builder.
        """
        super().__init__(
            pdl.OperationType.get(),
            [_get_op_result_or_value(h) for h in handles],
            deduplicate=deduplicate,
            loc=loc,
            ip=ip,
        )


class PDLMatchOp:
    """Builder extension for the PDL-match transform op."""

    def __init__(
        self,
        target: Union[Operation, Value],
        pattern_name: Union[Attribute, str],
        *,
        loc=None,
        ip=None,
    ):
        """Build the op matching ``pattern_name`` against ``target``.

        Args:
          target: handle operand, converted with ``_get_op_result_or_value``.
          pattern_name: an ``Attribute`` used as-is, or a string that is
            wrapped into a ``FlatSymbolRefAttr``.
          loc: optional location forwarded to the base builder.
          ip: optional insertion point forwarded to the base builder.
        """
        super().__init__(
            pdl.OperationType.get(),
            _get_op_result_or_value(target),
            _get_symbol_ref_attr(pattern_name),
            loc=loc,
            ip=ip,
        )


class SequenceOp:
    """Builder extension for the transform sequence op.

    Two call forms are accepted:
      * ``SequenceOp(result_types, root=None)``
      * ``SequenceOp(root=None)``
    Construction appends a body block with one ``pdl.OperationType``
    argument to the first region.
    """

    @overload
    def __init__(
        self,
        resultsOrRoot: Sequence[Type],
        optionalRoot: Optional[Union[Operation, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        resultsOrRoot: Optional[Union[Operation, Value]] = None,
        # Python's None (not the MLIR ``NoneType`` IR type): in this form the
        # second positional argument must be omitted.
        optionalRoot: None = None,
        *,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(self, resultsOrRoot=None, optionalRoot=None, *, loc=None, ip=None):
        if isinstance(resultsOrRoot, Sequence):
            # First form: explicit result types; the root comes second.
            results = resultsOrRoot
            root = optionalRoot
        else:
            # Second form: the first positional argument is the root and
            # ``optionalRoot`` is ignored.
            results = []
            root = resultsOrRoot
        # Test against None instead of truthiness so that a handle object
        # which happens to be falsy is never silently dropped.
        if root is not None:
            root = _get_op_result_or_value(root)
        super().__init__(results_=results, root=root, loc=loc, ip=ip)
        self.regions[0].blocks.append(pdl.OperationType.get())

    @property
    def body(self) -> Block:
        """The first block of the op's first region."""
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        """The first argument of :attr:`body`."""
        return self.body.arguments[0]


class WithPDLPatternsOp:
    """Builder extension for the with-PDL-patterns transform op.

    Construction appends a body block with one ``pdl.OperationType``
    argument to the first region.
    """

    def __init__(
        self,
        target: Optional[Union[Operation, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op, optionally rooted at ``target``.

        Args:
          target: optional root handle, converted with
            ``_get_op_result_or_value`` when given.
          loc: optional location forwarded to the base builder.
          ip: optional insertion point forwarded to the base builder.
        """
        # Explicit None check: a falsy-but-present handle must still be used.
        root = _get_op_result_or_value(target) if target is not None else None
        super().__init__(root=root, loc=loc, ip=ip)
        self.regions[0].blocks.append(pdl.OperationType.get())

    @property
    def body(self) -> Block:
        """The first block of the op's first region."""
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        """The first argument of :attr:`body`."""
        return self.body.arguments[0]


class YieldOp:
    """Builder extension for the transform yield op."""

    def __init__(
        self,
        operands: Optional[Union[Operation, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op yielding ``operands``.

        Args:
          operands: an ``Operation`` or a sequence of values, passed through
            ``_get_op_results_or_values`` (presumably expanding an operation
            into its results -- confirm). ``None`` means no operands.
          loc: optional location forwarded to the base builder.
          ip: optional insertion point forwarded to the base builder.
        """
        # A ``None`` default avoids a shared mutable list across calls.
        if operands is None:
            operands = []
        super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)