1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5try:
6  from ..ir import *
7except ImportError as e:
8  raise RuntimeError("Error loading imports from extension module") from e
9
10from typing import Any, Optional, Sequence, Union
11from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
12
class ForOp:
  """Python-side builder extensions for the SCF `for` operation.

  Provides a convenience constructor plus accessors for the loop body, the
  induction variable and the body-local loop-carried values.
  """

  def __init__(self,
               lower_bound,
               upper_bound,
               step,
               iter_args: Optional[Union[Operation, OpView,
                                         Sequence[Value]]] = None,
               *,
               loc=None,
               ip=None):
    """Builds an SCF `for` operation whose body region holds one block.

    Args:
      lower_bound: value (or single-result op) used as the loop lower bound.
      upper_bound: value (or single-result op) used as the loop upper bound.
      step: value (or single-result op) used as the loop step.
      iter_args: optional loop-carried initial values, given either as a
        sequence of values or as an operation whose results supply them.
        The new op gets one result per carried value, with the same type.
      loc: optional location for the new operation.
      ip: optional insertion point for the new operation.
    """
    carried = _get_op_results_or_values(
        [] if iter_args is None else iter_args)
    carried_types = [value.type for value in carried]
    bound_operands = [
        _get_op_result_or_value(v) for v in (lower_bound, upper_bound, step)
    ]
    built = self.build_generic(
        regions=1,
        results=carried_types,
        operands=bound_operands + list(carried),
        loc=loc,
        ip=ip)
    super().__init__(built)
    # Body block arguments: the index-typed induction variable first, then one
    # argument per loop-carried value (same types as the op results).
    body_region = self.regions[0]
    body_region.blocks.append(IndexType.get(), *carried_types)

  @property
  def body(self):
    """The single block forming the loop body."""
    body_region = self.regions[0]
    return body_region.blocks[0]

  @property
  def induction_variable(self):
    """The first body block argument, i.e. the loop induction variable."""
    block_args = self.body.arguments
    return block_args[0]

  @property
  def inner_iter_args(self):
    """Body block arguments that carry loop state, excluding the induction
    variable.

    These are the values visible inside the loop; the operands that seed them
    are obtained through `iter_args` instead.
    """
    block_args = self.body.arguments
    return block_args[1:]
67
68
class IfOp:
  """Specialization for the SCF if op class."""

  def __init__(self,
               cond,
               results_=None,
               *,
               hasElse=False,
               loc=None,
               ip=None):
    """Creates an SCF `if` operation.

    - `cond` is an MLIR value of `i1` type (or an operation producing exactly
      one such result) that determines which region of code will be executed.
    - `results_` is an optional sequence of result types for the operation;
      it defaults to no results.
    - `hasElse` determines whether a block is created in the else region; when
      False the else region is left without any block.
    - `loc` / `ip` are the optional location and insertion point.
    """
    # Use `None` instead of a shared mutable `[]` default, and copy into a
    # fresh list so the caller's sequence is never aliased by the builder.
    results = [] if results_ is None else list(results_)
    super().__init__(
        self.build_generic(
            regions=2,
            results=results,
            # Normalize like ForOp does, so a single-result op is accepted too.
            operands=[_get_op_result_or_value(cond)],
            loc=loc,
            ip=ip))
    # The then/else blocks are created without block arguments.
    self.regions[0].blocks.append()
    if hasElse:
      self.regions[1].blocks.append()

  @property
  def then_block(self):
    """Returns the then block of the if operation."""
    return self.regions[0].blocks[0]

  @property
  def else_block(self):
    """Returns the else block of the if operation.

    When built by this class with `hasElse=False`, the else region has no
    block, so this lookup is expected to fail in that case.
    """
    return self.regions[1].blocks[0]
108