# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
  from ..ir import *
  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
  from ..dialects import pdl
except ImportError as e:
  raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Union


def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]],
                    default_value: Optional[int] = None) -> IntegerAttr:
  """Return `arg` as a signless 64-bit `IntegerAttr`.

  Args:
    arg: An existing `IntegerAttr` (returned unchanged), a Python int (wrapped
      in a new signless i64 attribute), or None.
    default_value: Integer used in place of `arg` when `arg` is None.

  Returns:
    The attribute form of `arg` (or of `default_value`).

  Raises:
    ValueError: if both `arg` and `default_value` are None.
  """
  if isinstance(arg, IntegerAttr):
    return arg

  if arg is None:
    # Explicit raise instead of `assert` so the check survives `python -O`;
    # otherwise None would be handed straight to IntegerAttr.get.
    if default_value is None:
      raise ValueError("must provide default value")
    arg = default_value

  return IntegerAttr.get(IntegerType.get_signless(64), arg)


def _get_str_attr(arg: Union[str, StringAttr]) -> StringAttr:
  """Return `arg` as a `StringAttr`, wrapping a Python str if needed."""
  return arg if isinstance(arg, StringAttr) else StringAttr.get(arg)


def _get_bool_attr(arg: Union[bool, BoolAttr]) -> BoolAttr:
  """Return `arg` as a `BoolAttr`, wrapping a Python bool if needed."""
  return arg if isinstance(arg, BoolAttr) else BoolAttr.get(arg)


class GetParentForOp:
  """Extension for GetParentForOp.

  Presumably mixed into the ODS-generated op class, which supplies the
  `super().__init__` builder -- NOTE(review): confirm against the generator.
  """

  def __init__(self,
               target: Union[Operation, Value],
               *,
               num_loops: int = 1,
               ip=None,
               loc=None):
    """Build the op with a `!pdl.operation` result.

    Args:
      target: Operation or value used as the op's operand.
      num_loops: Int or `IntegerAttr`, stored as an i64 attribute; a None
        value falls back to 1.
      ip: Optional insertion point forwarded to the base builder.
      loc: Optional location forwarded to the base builder.
    """
    super().__init__(
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        num_loops=_get_int64_attr(num_loops, default_value=1),
        ip=ip,
        loc=loc)


class LoopOutlineOp:
  """Extension for LoopOutlineOp."""

  def __init__(self,
               target: Union[Operation, Value],
               *,
               func_name: Union[str, StringAttr],
               ip=None,
               loc=None):
    """Build the op with a `!pdl.operation` result.

    Args:
      target: Operation or value used as the op's operand.
      func_name: Python str or `StringAttr` for the `func_name` attribute.
      ip: Optional insertion point forwarded to the base builder.
      loc: Optional location forwarded to the base builder.
    """
    super().__init__(
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        func_name=_get_str_attr(func_name),
        ip=ip,
        loc=loc)


class LoopPeelOp:
  """Extension for LoopPeelOp."""

  def __init__(self,
               target: Union[Operation, Value],
               *,
               fail_if_already_divisible: Union[bool, BoolAttr] = False,
               ip=None,
               loc=None):
    """Build the op with a `!pdl.operation` result.

    Args:
      target: Operation or value used as the op's operand.
      fail_if_already_divisible: Python bool or `BoolAttr` for the
        `fail_if_already_divisible` attribute (default False).
      ip: Optional insertion point forwarded to the base builder.
      loc: Optional location forwarded to the base builder.
    """
    super().__init__(
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        fail_if_already_divisible=_get_bool_attr(fail_if_already_divisible),
        ip=ip,
        loc=loc)


class LoopPipelineOp:
  """Extension for LoopPipelineOp."""

  def __init__(self,
               target: Union[Operation, Value],
               *,
               iteration_interval: Optional[Union[int, IntegerAttr]] = None,
               read_latency: Optional[Union[int, IntegerAttr]] = None,
               ip=None,
               loc=None):
    """Build the op with a `!pdl.operation` result.

    Args:
      target: Operation or value used as the op's operand.
      iteration_interval: Int or `IntegerAttr` (i64); None means 1.
      read_latency: Int or `IntegerAttr` (i64); None means 10.
      ip: Optional insertion point forwarded to the base builder.
      loc: Optional location forwarded to the base builder.
    """
    super().__init__(
        pdl.OperationType.get(),
        _get_op_result_or_value(target),
        iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
        read_latency=_get_int64_attr(read_latency, default_value=10),
        ip=ip,
        loc=loc)


class LoopUnrollOp:
  """Extension for LoopUnrollOp."""

  def __init__(self,
               target: Union[Operation, Value],
               *,
               factor: Union[int, IntegerAttr],
               ip=None,
               loc=None):
    """Build the op from `target` and an unroll `factor`.

    Unlike the other builders in this file, no result type is passed to the
    base builder -- presumably because the op declares no results; confirm
    against the op definition.

    Args:
      target: Operation or value used as the op's operand.
      factor: Int or `IntegerAttr` (i64). There is no default, so passing
        None raises ValueError.
      ip: Optional insertion point forwarded to the base builder.
      loc: Optional location forwarded to the base builder.
    """
    super().__init__(
        _get_op_result_or_value(target),
        factor=_get_int64_attr(factor),
        ip=ip,
        loc=loc)