# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from ..ir import *
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Any, Optional, Sequence, Union
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values


class ForOp:
    """Specialization for the SCF for op class."""

    def __init__(self,
                 lower_bound,
                 upper_bound,
                 step,
                 iter_args: Optional[Union[Operation, OpView,
                                           Sequence[Value]]] = None,
                 *,
                 loc=None,
                 ip=None):
        """Creates an SCF `for` operation.

        - `lower_bound` is the value to use as lower bound of the loop.
        - `upper_bound` is the value to use as upper bound of the loop.
        - `step` is the value to use as loop step.
        - `iter_args` is a list of additional loop-carried arguments or an
          operation producing them as results.

        The bounds and step may each be given as a value or as an operation
        producing it; they are normalized with `_get_op_result_or_value`.
        """
        if iter_args is None:
            iter_args = []
        iter_args = _get_op_results_or_values(iter_args)

        # The op yields one result per loop-carried value, with matching types.
        results = [arg.type for arg in iter_args]
        super().__init__(
            self.build_generic(
                regions=1,
                results=results,
                operands=[
                    _get_op_result_or_value(o)
                    for o in [lower_bound, upper_bound, step]
                ] + list(iter_args),
                loc=loc,
                ip=ip))
        # Body block arguments: an `index`-typed induction variable first,
        # followed by one argument per loop-carried value (same types as the
        # results). `induction_variable` and `inner_iter_args` rely on this order.
        self.regions[0].blocks.append(IndexType.get(), *results)

    @property
    def body(self):
        """Returns the body (block) of the loop."""
        return self.regions[0].blocks[0]

    @property
    def induction_variable(self):
        """Returns the induction variable of the loop."""
        return self.body.arguments[0]

    @property
    def inner_iter_args(self):
        """Returns the loop-carried arguments usable within the loop.

        To obtain the loop-carried operands, use `iter_args`.
        """
        return self.body.arguments[1:]


class IfOp:
    """Specialization for the SCF if op class."""

    def __init__(self,
                 cond,
                 results_: Optional[Sequence[Type]] = None,
                 *,
                 hasElse=False,
                 loc=None,
                 ip=None):
        """Creates an SCF `if` operation.

        - `cond` is an MLIR value of `i1` type determining which region of code
          will be executed. An operation producing that value is also accepted,
          consistent with how `ForOp` treats its bounds.
        - `results_` is an optional sequence of types of the values produced by
          the operation; `None` (the default) means no results.
        - `hasElse` determines whether the if operation has the else branch,
          i.e. whether a block is created in the else region.
        """
        # Copy into a fresh list: avoids a shared mutable default and never
        # aliases the caller's sequence.
        results = list(results_) if results_ is not None else []
        operands = [_get_op_result_or_value(cond)]
        super().__init__(
            self.build_generic(
                regions=2,
                results=results,
                operands=operands,
                loc=loc,
                ip=ip))
        # build_generic creates both regions empty. The then region always gets
        # an argument-less block; the else region only when requested.
        self.regions[0].blocks.append()
        if hasElse:
            self.regions[1].blocks.append()

    @property
    def then_block(self):
        """Returns the then block of the if operation."""
        return self.regions[0].blocks[0]

    @property
    def else_block(self):
        """Returns the else block of the if operation.

        Only meaningful when the else region has a block (e.g. the operation was
        created with `hasElse=True`). Otherwise the else region's block list is
        empty and indexing its first block fails.
        """
        return self.regions[1].blocks[0]