1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4"""DSL for constructing affine expressions and maps.
5
6These python wrappers allow construction of affine expressions in a more
7pythonic fashion that is later instantiated as an IR AffineExpr. Separating the
8AST from construction of the map allows for manipulations of symbols and dims
9beyond the scope of one expression.
10
11Affine expression construction:
12  >>> with _ir.Context():
13  ...   s = AffineBuildState()
14  ...   (S.K + S.M).build(s)
15  ...   (S.K * S.M).build(s)
16  ...   (S.K // S.M).build(s)
17  ...   (S.K / S.M).build(s)
18  ...   (S.K % 4).build(s)
19  ...   (D.i + D.j * 4).build(s)
20  ...   s
21  AffineExpr(s0 + s1)
22  AffineExpr(s0 * s1)
23  AffineExpr(s0 floordiv s1)
24  AffineExpr(s0 ceildiv s1)
25  AffineExpr(s0 mod 4)
26  AffineExpr(d0 + d1 * 4)
27  AffineBuildState<
28    symbols={'K': 0, 'M': 1}
29    dims={'i': 0, 'j': 1}>
30
In the DSL, dimensions and symbols are instances of DimDef and SymbolDef that
are uniqued by name. The shortcut "expando" instances `D` and `S` create (or
return the existing) DimDef/SymbolDef named after whichever attribute is
accessed on them:
34
35Referencing a named dimension:
36
37  >>> D.i
38  Dim(i)
39  >>> D.a is D.b
40  False
41  >>> D.a is D.a
42  True
43
44Referencing a named symbol:
45
46  >>> S.foobar
47  Symbol(foobar)
48  >>> S.a is S.b
49  False
50  >>> S.a is S.a
51  True
52"""
53
54from typing import Callable, Dict, Optional, Tuple, Union
55
56from ..... import ir as _ir
57
# Names exported by `from <module> import *`. The expression node classes
# (AffineConstantExpr, AffineBinaryExprDef) are defined below but not listed.
__all__ = [
    "AffineBuildState",
    "AffineExprDef",
    "D",
    "DimDef",
    "S",
    "SymbolDef",
]
66
67
class AffineBuildState:
  """Internal state for the AffineExprDef._create impls.

  Note that a "local" AffineBuildState can be created relative to a "global"
  AffineBuildState. In that case, any affine expressions built will inherit
  symbol and dim bindings from the global state and will update both as new
  ones are discovered. This allows for building expressions across contexts
  which share a common symbol and dim space.

  Attributes:
    all_symbols: Name -> position map of every known symbol. Aliased with the
      global state's map when a global state is given.
    all_dims: Name -> position map of every known dim. Aliased with the global
      state's map when a global state is given.
    local_symbols: Symbols referenced through this particular state.
    local_dims: Dims referenced through this particular state.
    allow_new_symbols: Whether unseen symbol names may be assigned positions.
    allow_new_dims: Whether unseen dim names may be assigned positions.
  """

  def __init__(self,
               *,
               global_state: Optional["AffineBuildState"] = None,
               allow_new_symbols: bool = True,
               allow_new_dims: bool = True):
    if global_state is None:
      self.all_symbols = dict()  # type: Dict[str, int]
      self.all_dims = dict()  # type: Dict[str, int]
    else:
      # Alias the global dicts so positions assigned through this state are
      # shared with the global state (and any other state aliasing it).
      self.all_symbols = global_state.all_symbols
      self.all_dims = global_state.all_dims

    # Map of symbols and dims referenced in the current build.
    self.local_symbols = dict()  # type: Dict[str, int]
    self.local_dims = dict()  # type: Dict[str, int]
    self.allow_new_symbols = allow_new_symbols
    self.allow_new_dims = allow_new_dims

  @staticmethod
  def _lookup_or_assign(name: str, all_positions: Dict[str, int],
                        local_positions: Dict[str, int], allow_new: bool,
                        kind: str) -> int:
    """Returns the position of `name`, assigning a new one if allowed.

    A newly assigned position equals the number of entries already in
    `all_positions`. The name is recorded in `local_positions` either way.

    Raises:
      ValueError: If `name` is unknown and `allow_new` is False.
    """
    pos = all_positions.get(name)
    if pos is None:
      if not allow_new:
        raise ValueError(
            f"New {kind} not allowed in the current affine expression: "
            f"Requested '{name}', Available: {all_positions}")
      pos = len(all_positions)
      all_positions[name] = pos
    local_positions[name] = pos
    return pos

  def get_dim(self, dimname: str) -> int:
    """Gets the dim position given a name.

    Raises:
      ValueError: If the dim is unknown and new dims are not allowed.
    """
    return self._lookup_or_assign(dimname, self.all_dims, self.local_dims,
                                  self.allow_new_dims, "dimensions")

  def get_symbol(self, symname: str) -> int:
    """Gets a symbol position given a name.

    Raises:
      ValueError: If the symbol is unknown and new symbols are not allowed.
    """
    return self._lookup_or_assign(symname, self.all_symbols,
                                  self.local_symbols, self.allow_new_symbols,
                                  "symbols")

  @property
  def local_dim_count(self) -> int:
    return len(self.local_dims)

  @property
  def local_symbol_count(self) -> int:
    return len(self.local_symbols)

  @property
  def dim_count(self) -> int:
    return len(self.all_dims)

  @property
  def symbol_count(self) -> int:
    return len(self.all_symbols)

  def __repr__(self):
    # Only the bindings referenced through this state are shown.
    return "\n".join([
        "AffineBuildState<",
        f"  symbols={self.local_symbols}",
        f"  dims={self.local_dims}>",
    ])
144
145
class AffineExprDef:
  """Base class for an affine expression being defined.

  Subclasses override `_create` to produce the IR expression. The arithmetic
  operators build AffineBinaryExprDef trees; `int` operands are wrapped in
  AffineConstantExpr via `coerce_from`.
  """

  def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
    """Builds the corresponding _ir.AffineExpr from the definitions.

    Args:
      state: State used to resolve dim/symbol names to positions. A fresh
        AffineBuildState is created when omitted.
    """
    state = AffineBuildState() if state is None else state
    expr = self._create(state)
    return expr

  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
    """Creates the IR expression; must be overridden by subclasses."""
    raise NotImplementedError()

  @staticmethod
  def coerce_from(py_value):
    """Returns `py_value` as an AffineExprDef, wrapping an `int` as a constant.

    Any other value must already be an AffineExprDef (asserted).
    """
    if isinstance(py_value, int):
      return AffineConstantExpr(py_value)
    assert isinstance(py_value, AffineExprDef)
    return py_value

  def visit_affine_exprs(self, callback):
    """Visits all AffineExprDefs including self."""
    callback(self)

  def __add__(lhs, rhs):
    rhs = AffineExprDef.coerce_from(rhs)
    return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)

  def __radd__(rhs, lhs):
    # Reached for `int + expr` once int.__add__ returns NotImplemented.
    # Returning NotImplemented for other operand types lets Python raise its
    # usual TypeError.
    if not isinstance(lhs, int):
      return NotImplemented
    return AffineBinaryExprDef(_ir.AffineAddExpr,
                               AffineExprDef.coerce_from(lhs), rhs)

  def __mul__(lhs, rhs):
    rhs = AffineExprDef.coerce_from(rhs)
    return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)

  def __rmul__(rhs, lhs):
    # Reached for `int * expr`, e.g. `4 * D.j`; see __radd__.
    if not isinstance(lhs, int):
      return NotImplemented
    return AffineBinaryExprDef(_ir.AffineMulExpr,
                               AffineExprDef.coerce_from(lhs), rhs)

  def __mod__(lhs, rhs):
    rhs = AffineExprDef.coerce_from(rhs)
    return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)

  def __floordiv__(lhs, rhs):
    rhs = AffineExprDef.coerce_from(rhs)
    return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)

  def __truediv__(lhs, rhs):
    # TODO: Not really a ceil div - taking liberties for the DSL.
    rhs = AffineExprDef.coerce_from(rhs)
    return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
190
191
class AffineConstantExpr(AffineExprDef):
  """Definition of an affine expression holding a fixed integer value."""

  def __init__(self, value: int):
    assert isinstance(value, int)
    self.value = value

  def __repr__(self):
    return "Const({})".format(self.value)

  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
    # `state` is unused: a constant binds no dims or symbols.
    constant_ctor = _ir.AffineConstantExpr
    return constant_ctor.get(self.value)
204
205
class AffineBinaryExprDef(AffineExprDef):
  """Definition of an affine expression combining two operand definitions.

  `ir_ctor` is the IR expression class whose `get(lhs, rhs)` builds the node.
  """

  def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
    self.ir_ctor, self.lhs, self.rhs = ir_ctor, lhs, rhs

  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
    # Operands are created left to right; this fixes the order in which new
    # dims/symbols are assigned positions in `state`.
    lhs_expr = self.lhs._create(state)
    rhs_expr = self.rhs._create(state)
    return self.ir_ctor.get(lhs_expr, rhs_expr)

  def visit_affine_exprs(self, callback):
    """Visits self first, then the lhs subtree, then the rhs subtree."""
    super().visit_affine_exprs(callback)
    for operand in (self.lhs, self.rhs):
      operand.visit_affine_exprs(callback)

  def __repr__(self):
    return f"{self.ir_ctor.__name__}({self.lhs!r}, {self.rhs!r})"
225
226
class DimDef(AffineExprDef):
  """Represents a named dimension.

  Instances are uniqued by name: constructing a DimDef with a name that has
  been seen before returns the existing instance. Building one resolves its
  position through AffineBuildState.get_dim.
  """
  # Class-level registry of every DimDef created so far, keyed by name.
  ALL_DIMS = dict()  # type: Dict[str, "DimDef"]

  def __new__(cls, dimname: str):
    existing = cls.ALL_DIMS.get(dimname)
    if existing is not None:
      return existing
    new = super().__new__(cls)
    new.dimname = dimname
    cls.ALL_DIMS[dimname] = new
    return new

  def __repr__(self):
    return f"Dim({self.dimname})"

  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
    pos = state.get_dim(self.dimname)
    return _ir.AffineDimExpr.get(position=pos)

  @classmethod
  def create_expando(cls):
    """Create an expando object that creates unique dims based on attr access.
    """

    class ExpandoDims:

      def __getattr__(self, n):
        # Tooling probes dunder attributes (e.g. inspect.unwrap looks up
        # `__wrapped__`); report them as missing rather than registering
        # them as dims.
        if n.startswith("__") and n.endswith("__"):
          raise AttributeError(n)
        return cls(n)

    return ExpandoDims()
260
261
class SymbolDef(AffineExprDef):
  """Represents a named symbol.

  Instances are uniqued by name; building one resolves its position through
  AffineBuildState.get_symbol.

    >>> s1 = SymbolDef("s1")
    >>> s1
    Symbol(s1)
    >>> s2 = SymbolDef("s2")
    >>> s1 is s2
    False
    >>> s1 is SymbolDef("s1")
    True
  """
  # Class-level registry of every SymbolDef created so far, keyed by name.
  ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]

  def __new__(cls, symname: str):
    existing = cls.ALL_SYMBOLS.get(symname)
    if existing is not None:
      return existing
    new = super().__new__(cls)
    new.symname = symname
    cls.ALL_SYMBOLS[symname] = new
    return new

  def __repr__(self):
    return f"Symbol({self.symname})"

  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
    pos = state.get_symbol(self.symname)
    return _ir.AffineSymbolExpr.get(position=pos)

  @classmethod
  def create_expando(cls):
    """Create an expando class that creates unique symbols based on attr access.
    """

    class ExpandoSymbols:

      def __getattr__(self, n):
        # Tooling probes dunder attributes (e.g. inspect.unwrap looks up
        # `__wrapped__`); report them as missing rather than registering
        # them as symbols.
        if n.startswith("__") and n.endswith("__"):
          raise AttributeError(n)
        return cls(n)

    return ExpandoSymbols()
303
304
# Global accessors for on-demand dims and symbols: attribute access such as
# `D.i` or `S.K` yields the name-uniqued DimDef/SymbolDef of that name.
D, S = DimDef.create_expando(), SymbolDef.create_expando()
308