# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""DSL for constructing affine expressions and maps.

These python wrappers allow construction of affine expressions in a more
pythonic fashion that is later instantiated as an IR AffineExpr. Separating the
AST from construction of the map allows for manipulations of symbols and dims
beyond the scope of one expression.

Affine expression construction (note that `/` builds a ceildiv and `//` builds
a floordiv):
  >>> with _ir.Context():
  ...   s = AffineBuildState()
  ...   (S.K + S.M).build(s)
  ...   (S.K * S.M).build(s)
  ...   (S.K // S.M).build(s)
  ...   (S.K / S.M).build(s)
  ...   (S.K % 4).build(s)
  ...   (D.i + D.j * 4).build(s)
  ...   s
  AffineExpr(s0 + s1)
  AffineExpr(s0 * s1)
  AffineExpr(s0 floordiv s1)
  AffineExpr(s0 ceildiv s1)
  AffineExpr(s0 mod 4)
  AffineExpr(d0 + d1 * 4)
  AffineBuildState<
    symbols={'K': 0, 'M': 1}
    dims={'i': 0, 'j': 1}>

In the DSL, dimensions and symbols are name-uniqued instances of DimDef and
SymbolDef. There are shortcut "expando" instances that will create a
corresponding DimDef/SymbolDef upon accessing an attribute:

Referencing a named dimension:

  >>> D.i
  Dim(i)
  >>> D.a is D.b
  False
  >>> D.a is D.a
  True

Referencing a named symbol:

  >>> S.foobar
  Symbol(foobar)
  >>> S.a is S.b
  False
  >>> S.a is S.a
  True
"""

from typing import Callable, Dict, Optional, Tuple, Union

from ..... import ir as _ir

__all__ = [
    "AffineBuildState",
    "AffineExprDef",
    "D",
    "DimDef",
    "S",
    "SymbolDef",
]


class AffineBuildState:
  """Internal state for the AffineExprDef._create impls.

  Maps dim and symbol names to their positions while expressions are built.

  Note that a "local" AffineBuildState can be created relative to a "global"
  AffineBuildState. In that case, any affine expressions built will inherit
  symbol and dim bindings from the global state and will update both as new
  ones are discovered. This allows for building expressions across contexts
  which share a common symbol and dim space.

  Attributes:
    all_symbols: Symbol name -> position. This is the very same dict object as
      the global state's when a global_state is given.
    all_dims: Dim name -> position; shared with the global state like
      all_symbols.
    local_symbols: Symbols that were first assigned a position by this state.
    local_dims: Dims that were first assigned a position by this state.
    allow_new_symbols: When False, requesting an unknown symbol raises.
    allow_new_dims: When False, requesting an unknown dim raises.
  """

  def __init__(self,
               *,
               global_state: Optional["AffineBuildState"] = None,
               allow_new_symbols: bool = True,
               allow_new_dims: bool = True):
    if global_state is None:
      self.all_symbols: Dict[str, int] = {}
      self.all_dims: Dict[str, int] = {}
    else:
      # Alias (do not copy) the global dicts so that positions assigned through
      # this state are visible to the global state and vice versa.
      self.all_symbols = global_state.all_symbols
      self.all_dims = global_state.all_dims

    # Map of symbols and dims newly introduced by builds using this state.
    self.local_symbols: Dict[str, int] = {}
    self.local_dims: Dict[str, int] = {}
    self.allow_new_symbols = allow_new_symbols
    self.allow_new_dims = allow_new_dims

  def get_dim(self, dimname: str) -> int:
    """Gets the dim position given a name.

    An unknown dim is assigned the next free position (the number of dims
    known so far) and recorded in both all_dims and local_dims.

    Raises:
      ValueError: If the dim is unknown and new dims are not allowed.
    """
    pos = self.all_dims.get(dimname)
    if pos is None:
      if not self.allow_new_dims:
        raise ValueError(
            f"New dimensions not allowed in the current affine expression: "
            f"Requested '{dimname}', Available: {self.all_dims}")
      pos = len(self.all_dims)
      self.all_dims[dimname] = pos
      self.local_dims[dimname] = pos
    return pos

  def get_symbol(self, symname: str) -> int:
    """Gets a symbol position given a name.

    An unknown symbol is assigned the next free position (the number of
    symbols known so far) and recorded in both all_symbols and local_symbols.

    Raises:
      ValueError: If the symbol is unknown and new symbols are not allowed.
    """
    pos = self.all_symbols.get(symname)
    if pos is None:
      if not self.allow_new_symbols:
        raise ValueError(
            f"New symbols not allowed in the current affine expression: "
            f"Requested '{symname}', Available: {self.all_symbols}")
      pos = len(self.all_symbols)
      self.all_symbols[symname] = pos
      self.local_symbols[symname] = pos
    return pos

  @property
  def local_dim_count(self) -> int:
    """Number of dims first introduced through this state."""
    return len(self.local_dims)

  @property
  def local_symbol_count(self) -> int:
    """Number of symbols first introduced through this state."""
    return len(self.local_symbols)

  @property
  def dim_count(self) -> int:
    """Number of dims known to this state (including shared global ones)."""
    return len(self.all_dims)

  @property
  def symbol_count(self) -> int:
    """Number of symbols known to this state (including shared global ones)."""
    return len(self.all_symbols)

  def __repr__(self):
    # Only the bindings introduced through this state are shown.
    lines = ["AffineBuildState<"]
    lines.append(f"  symbols={self.local_symbols}")
    lines.append(f"  dims={self.local_dims}>")
    return "\n".join(lines)


class AffineExprDef:
  """Base class for an affine expression being defined.

  Instances form an expression tree that is only turned into an _ir.AffineExpr
  by build(), so dim and symbol positions are assigned by an AffineBuildState
  at build time. Python ints are accepted as operands and wrapped in an
  AffineConstantExpr. Operator mapping: `+` -> add, `*` -> mul, `%` -> mod,
  `//` -> floordiv, `/` -> ceildiv.
  """

  def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
    """Builds the corresponding _ir.AffineExpr from the definitions.

    Args:
      state: Build state used to assign dim/symbol positions. A fresh
        AffineBuildState is used when omitted.
    """
    if state is None:
      state = AffineBuildState()
    return self._create(state)

  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
    """Creates the IR expression for this node; implemented by subclasses."""
    raise NotImplementedError()

  @staticmethod
  def coerce_from(py_value):
    """Wraps an int in an AffineConstantExpr; passes AffineExprDefs through."""
    if isinstance(py_value, int):
      return AffineConstantExpr(py_value)
    assert isinstance(py_value, AffineExprDef)
    return py_value

  def visit_affine_exprs(self, callback):
    """Visits all AffineExprDefs including self."""
    callback(self)

  def __add__(lhs, rhs):
    rhs = AffineExprDef.coerce_from(rhs)
    return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)

  def __radd__(self, lhs):
    # Supports `<int> + <expr>`. Other left operands are declined so Python
    # raises its usual TypeError instead of an assertion in coerce_from.
    if not isinstance(lhs, (int, AffineExprDef)):
      return NotImplemented
    lhs = AffineExprDef.coerce_from(lhs)
    return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, self)

  def __mul__(lhs, rhs):
    rhs = AffineExprDef.coerce_from(rhs)
    return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)

  def __rmul__(self, lhs):
    # Supports `<int> * <expr>` (e.g. `4 * D.i`); see __radd__.
    if not isinstance(lhs, (int, AffineExprDef)):
      return NotImplemented
    lhs = AffineExprDef.coerce_from(lhs)
    return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, self)

  def __mod__(lhs, rhs):
    rhs = AffineExprDef.coerce_from(rhs)
    return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)

  def __floordiv__(lhs, rhs):
    rhs = AffineExprDef.coerce_from(rhs)
    return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)

  def __truediv__(lhs, rhs):
    # TODO: Not really a ceil div - taking liberties for the DSL.
    rhs = AffineExprDef.coerce_from(rhs)
    return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)


class AffineConstantExpr(AffineExprDef):
  """An affine constant being defined."""

  def __init__(self, value: int):
    assert isinstance(value, int)
    self.value = value

  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
    return _ir.AffineConstantExpr.get(self.value)

  def __repr__(self):
    return f"Const({self.value})"


class AffineBinaryExprDef(AffineExprDef):
  """An affine binary expression being defined.

  `ir_ctor` is the _ir expression class (e.g. _ir.AffineAddExpr) whose `get`
  combines the built operands.
  """

  def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
    self.ir_ctor = ir_ctor
    self.lhs = lhs
    self.rhs = rhs

  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
    # lhs is built before rhs, so positions of newly seen dims/symbols are
    # assigned in left-to-right order.
    return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))

  def visit_affine_exprs(self, callback):
    """Visits all AffineExprDefs including self."""
    super().visit_affine_exprs(callback)
    self.lhs.visit_affine_exprs(callback)
    self.rhs.visit_affine_exprs(callback)

  def __repr__(self):
    return f"{self.ir_ctor.__name__}({self.lhs!r}, {self.rhs!r})"


class DimDef(AffineExprDef):
  """Represents a named dimension.

  Instances are uniqued by name: constructing a DimDef with an existing name
  returns the existing instance.

  >>> d1 = DimDef("d1")
  >>> d1
  Dim(d1)
  >>> d2 = DimDef("d2")
  >>> d1 is d2
  False
  >>> d1 is DimDef("d1")
  True
  """
  ALL_DIMS: Dict[str, "DimDef"] = {}

  def __new__(cls, dimname: str):
    existing = cls.ALL_DIMS.get(dimname)
    if existing is not None:
      return existing
    new = super().__new__(cls)
    new.dimname = dimname
    cls.ALL_DIMS[dimname] = new
    return new

  def __repr__(self):
    return f"Dim({self.dimname})"

  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
    pos = state.get_dim(self.dimname)
    return _ir.AffineDimExpr.get(position=pos)

  @classmethod
  def create_expando(cls):
    """Create an expando object that creates unique dims based on attr access.

    `expando.foo` is equivalent to `cls("foo")`.
    """

    class ExpandoDims:

      def __getattr__(self, n):
        # Dunder lookups come from Python tooling/introspection (e.g.
        # inspect.unwrap probing `__wrapped__`), not from DSL users; answering
        # them would silently register bogus dims.
        if n.startswith("__") and n.endswith("__"):
          raise AttributeError(n)
        return cls(n)

    return ExpandoDims()


class SymbolDef(AffineExprDef):
  """Represents a named symbol.

  Instances are uniqued by name: constructing a SymbolDef with an existing
  name returns the existing instance.

  >>> s1 = SymbolDef("s1")
  >>> s1
  Symbol(s1)
  >>> s2 = SymbolDef("s2")
  >>> s1 is s2
  False
  >>> s1 is SymbolDef("s1")
  True
  """
  ALL_SYMBOLS: Dict[str, "SymbolDef"] = {}

  def __new__(cls, symname: str):
    existing = cls.ALL_SYMBOLS.get(symname)
    if existing is not None:
      return existing
    new = super().__new__(cls)
    new.symname = symname
    cls.ALL_SYMBOLS[symname] = new
    return new

  def __repr__(self):
    return f"Symbol({self.symname})"

  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
    pos = state.get_symbol(self.symname)
    return _ir.AffineSymbolExpr.get(position=pos)

  @classmethod
  def create_expando(cls):
    """Create an expando object that creates unique symbols based on attr access.

    `expando.foo` is equivalent to `cls("foo")`.
    """

    class ExpandoSymbols:

      def __getattr__(self, n):
        # See DimDef.create_expando: do not turn dunder probes into symbols.
        if n.startswith("__") and n.endswith("__"):
          raise AttributeError(n)
        return cls(n)

    return ExpandoSymbols()


# Global accessor for on-demand dims and symbols.
D = DimDef.create_expando()
S = SymbolDef.create_expando()