# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
  from ..ir import *
  from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
  from ..dialects import pdl
except ImportError as e:
  raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, overload, Sequence, Union

# NOTE(review): the classes below have no explicit base class, yet they call
# `super().__init__(...)` with op-builder arguments and use `self.regions`.
# They are presumably combined with the generated op classes of the same name
# elsewhere; confirm against the dialect loader.


def _get_symbol_ref_attr(value: Union[Attribute, str]):
  """Return `value` unchanged if it is an Attribute, else wrap it.

  A plain string is converted with `FlatSymbolRefAttr.get`.
  """
  if isinstance(value, Attribute):
    return value
  return FlatSymbolRefAttr.get(value)


class GetClosestIsolatedParentOp:
  """Convenience builder producing a single `pdl.OperationType` result."""

  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
    """Build the op from `target`, given as an Operation or a Value."""
    super().__init__(
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        loc=loc,
        ip=ip)


class MergeHandlesOp:
  """Convenience builder merging several handles into one result."""

  def __init__(self,
               handles: Sequence[Union[Operation, Value]],
               *,
               deduplicate: bool = False,
               loc=None,
               ip=None):
    """Build the op.

    Args:
      handles: operations or values; each is normalized with
        `_get_op_result_or_value` before being passed on.
      deduplicate: forwarded unchanged as the `deduplicate` builder argument.
    """
    super().__init__(
        pdl.OperationType.get(), [_get_op_result_or_value(h) for h in handles],
        deduplicate=deduplicate,
        loc=loc,
        ip=ip)


class PDLMatchOp:
  """Convenience builder that references a pattern by symbol name."""

  def __init__(self,
               target: Union[Operation, Value],
               pattern_name: Union[Attribute, str],
               *,
               loc=None,
               ip=None):
    """Build the op.

    Args:
      target: an Operation or Value, normalized to a Value.
      pattern_name: an Attribute (used as is) or a string (converted to a
        `FlatSymbolRefAttr`).
    """
    super().__init__(
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        _get_symbol_ref_attr(pattern_name),
        loc=loc,
        ip=ip)


class ReplicateOp:
  """Convenience builder with one `pdl.OperationType` result per handle."""

  def __init__(self,
               pattern: Union[Operation, Value],
               handles: Sequence[Union[Operation, Value]],
               *,
               loc=None,
               ip=None):
    """Build the op; the number of results equals `len(handles)`."""
    super().__init__(
        [pdl.OperationType.get()] * len(handles),
        _get_op_result_or_value(pattern),
        [_get_op_result_or_value(h) for h in handles],
        loc=loc,
        ip=ip)


class SequenceOp:
  """Convenience builder that also creates the op's body block.

  Two call forms are accepted:
    SequenceOp(result_types, root)  -- a sequence of result types, then root
    SequenceOp(root)                -- no results; root may be omitted/None
  """

  @overload
  def __init__(self,
               resultsOrRoot: Sequence[Type],
               optionalRoot: Optional[Union[Operation, Value]],
               *,
               loc=None,
               ip=None):
    ...

  @overload
  def __init__(self,
               resultsOrRoot: Optional[Union[Operation, Value]],
               optionalRoot: None = None,
               *,
               loc=None,
               ip=None):
    ...

  def __init__(self,
               resultsOrRoot=None,
               optionalRoot=None,
               *,
               loc=None,
               ip=None):
    # A Sequence in the first position is the list of result types and the
    # root comes second; otherwise the first argument is the root itself
    # (and `optionalRoot` is ignored).
    if isinstance(resultsOrRoot, Sequence):
      results = resultsOrRoot
      root = optionalRoot
    else:
      results = []
      root = resultsOrRoot
    # Explicit None check: "no root" must not depend on IR-object truthiness.
    root = _get_op_result_or_value(root) if root is not None else None
    super().__init__(results_=results, root=root, loc=loc, ip=ip)
    # Create the body block with a single `pdl.OperationType` argument.
    self.regions[0].blocks.append(pdl.OperationType.get())

  @property
  def body(self) -> Block:
    """The first block of the op's first region."""
    return self.regions[0].blocks[0]

  @property
  def bodyTarget(self) -> Value:
    """The first argument of `body`."""
    return self.body.arguments[0]


class WithPDLPatternsOp:
  """Convenience builder with an optional root that also creates its body."""

  def __init__(self,
               target: Optional[Union[Operation, Value]] = None,
               *,
               loc=None,
               ip=None):
    """Build the op; `target`, when given, is normalized to a Value."""
    # Explicit None check: "no target" must not depend on IR-object truthiness.
    root = _get_op_result_or_value(target) if target is not None else None
    super().__init__(root=root, loc=loc, ip=ip)
    # Create the body block with a single `pdl.OperationType` argument.
    self.regions[0].blocks.append(pdl.OperationType.get())

  @property
  def body(self) -> Block:
    """The first block of the op's first region."""
    return self.regions[0].blocks[0]

  @property
  def bodyTarget(self) -> Value:
    """The first argument of `body`."""
    return self.body.arguments[0]


class YieldOp:
  """Convenience builder accepting an Operation or a sequence of Values."""

  def __init__(self,
               operands: Optional[Union[Operation, Sequence[Value]]] = None,
               *,
               loc=None,
               ip=None):
    """Build the op.

    Args:
      operands: an Operation or a sequence of Values, normalized with
        `_get_op_results_or_values`. Omitting it (or passing None) means no
        operands.
    """
    # Avoid a shared mutable default; normalize "absent" to an empty list.
    if operands is None:
      operands = []
    super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)