# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Model classes representing a tensor comprehension.

These classes model the language more at an AST level as evaluated. Reasoning
about it typically involves processing this form into config objects that
represent actual op definitions (i.e. YAML).
"""

from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from enum import Enum

from ..... import ir as _ir
from .affine import *
from .scalar_expr import *
from .types import *
from .yaml_helper import *

###############################################################################
# Tensor expression nodes.
###############################################################################


class TensorExpression:
  """An expression that can appear on the RHS of a comprehension."""

  def to_scalar_expression(self) -> ScalarExpression:
    """Lowers this tensor expression to a ScalarExpression.

    Subclasses override this; the base implementation raises
    NotImplementedError.
    """
    raise NotImplementedError()

  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
    """Visits all tensor expressions reachable by the expression.

    The base implementation only invokes `callback` on `self`; subclasses that
    have child expressions override it to also recurse into their arguments.
    """
    callback(self)

  def collect_dim_uses(self, uses: Set["DimDef"]):
    """Collects all DimDefs reachable through this expression into `uses`."""

    def visit_dim_def(dim_def: AffineExprDef):
      if isinstance(dim_def, DimDef):
        uses.add(dim_def)

    def visit_affine_exprs(expr: "TensorExpression"):
      if isinstance(expr, TensorUse):
        for ind in expr.indices:
          ind.visit_affine_exprs(visit_dim_def)
      if isinstance(expr, TensorReduceFn):
        # TensorReduceFn keeps its ReduceFnUse in `reduce_use` (it has no
        # `reduce_fn` attribute). NOTE(review): this branch is currently not
        # reached, because TensorReduceFn.visit_tensor_exprs only invokes the
        # callback on its args and never on the reduction node itself.
        for ind in expr.reduce_use.reduce_dims:
          ind.visit_affine_exprs(visit_dim_def)

    self.visit_tensor_exprs(visit_affine_exprs)

  def collect_tensor_uses(self, uses: Set["TensorUse"]):
    """Collects all TensorUses reachable through this expression into `uses`."""

    def visit_tensor_use(expr: "TensorExpression"):
      if isinstance(expr, TensorUse):
        uses.add(expr)

    self.visit_tensor_exprs(visit_tensor_use)

  def collect_indices(self, indices: Set["index"]):
    """Collects all index accesses reachable through this expression."""

    def visit_index(expr: "TensorExpression"):
      if isinstance(expr, index):
        indices.add(expr)

    self.visit_tensor_exprs(visit_index)

  def collect_scalar_uses(self, uses: Set["ScalarDef"]):
    """Collects all ScalarDefs reachable through this expression."""

    def visit_scalar_def(expr: "TensorExpression"):
      if isinstance(expr, ScalarDef):
        uses.add(expr)

    self.visit_tensor_exprs(visit_scalar_def)

  def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
    return BinaryFn.add(self, rhs)

  def __mul__(self, rhs: "TensorExpression") -> "TensorExpression":
    return BinaryFn.mul(self, rhs)

  def __sub__(self, rhs: "TensorExpression") -> "TensorExpression":
    return BinaryFn.sub(self, rhs)

  def __hash__(self):
    # Identity-based hashing so expressions can be collected into sets by the
    # collect_* helpers above.
    return hash(id(self))


class TensorUse(TensorExpression):
  """A used tensor represented by its (tensor_name, indices).

  Note that forming a comprehension via direct assignment is performed through
  __setitem__ on the TensorDef level. However, performing a reduction with a
  compound op (only += is supported, via __iadd__) is done by doing a:
    TensorDef.__getitem__
    TensorUse.__iadd__
    TensorDef.__setitem__
  """

  def __init__(self, operand_def: "OperandDef",
               indices: Sequence[AffineExprDef]):
    """Creates a use of `operand_def` subscripted by `indices`.

    Args:
      operand_def: The operand being accessed.
      indices: The affine index expressions; stored as a tuple.
    """
    self.operand_def = operand_def
    self.indices = tuple(indices)

  def to_scalar_expression(self) -> ScalarExpression:
    return ScalarArg(self.tensor_name).expr()

  @property
  def tensor_name(self) -> str:
    """Name the operand was registered under (asserts it is registered)."""
    name = self.operand_def.name
    assert name is not None, "TensorDef not registered with an op"
    return name

  def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
    # Computes the reduction dims for implicit reductions. Assumes that the rhs
    # is the expression being reduced and self is being reduced into. Any
    # indices referenced on the rhs and not in self are considered reduction
    # dims. NOTE(review): the result is a set, so its iteration order is not
    # guaranteed to follow the order in which dims are encountered on the rhs.
    rhs_dims = set()
    lhs_dims = set()
    rhs.collect_dim_uses(rhs_dims)
    self.collect_dim_uses(lhs_dims)
    return rhs_dims - lhs_dims

  def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
    """Implements `lhs += rhs` as an add-reduction over the implicit dims."""
    return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)

  def __repr__(self):
    return (f"{self.operand_def.name}"
            f"[{', '.join([repr(i) for i in self.indices])}]")


class TensorFn(TensorExpression):
  """Application of a tensor function.

  Exactly one of `name` (a fixed function) or `operand_def` (a function
  attribute operand) must be given.
  """

  def __init__(self, kind: "FunctionKind", name: Optional[str],
               operand_def: Optional["OperandDef"], type_var: Optional[TypeVar],
               args: Sequence[TensorExpression]):
    """Creates a function application.

    Args:
      kind: Whether this is a unary, binary, or type conversion function.
      name: The function name, or None when `operand_def` is given.
      operand_def: The function attribute operand, or None when `name` is
        given.
      type_var: The target type for type conversion functions, else None.
      args: The argument expressions.

    Raises:
      ValueError: If not exactly one of `name` and `operand_def` is truthy.
    """
    if bool(name) + bool(operand_def) != 1:
      raise ValueError("One of 'name', 'operand_def' must be specified")
    self.name = name
    self.kind = kind
    self.operand_def = operand_def
    self.type_var = type_var
    self.args = args

  def to_scalar_expression(self) -> ScalarExpression:
    if self.operand_def:
      assert self.operand_def.name, "TensorFn not registered with an op"
    attr_name = self.operand_def.name if self.operand_def else None
    args = [arg.to_scalar_expression() for arg in self.args]
    return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()

  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
    # Visit this node first, then recurse into every argument.
    super().visit_tensor_exprs(callback)
    for arg in self.args:
      arg.visit_tensor_exprs(callback)

  def __repr__(self):
    name = self.operand_def.name if self.operand_def else self.name
    return (f"{self.kind.name}.{name}(type_var={self.type_var}, "
            f"args={', '.join(repr(a) for a in self.args)})")


class TensorReduceFn(TensorExpression):
  """Application of a reduction function.

  This captures the lhs (initial value) separately from the rhs.
  """

  def __init__(self, reduce_use: "ReduceFnUse",
               args: Sequence[TensorExpression]):
    self.reduce_use = reduce_use
    # Bound later by Comprehension when the reduction is assigned to a tensor.
    self.lhs = None  # type: Optional[TensorUse]
    self.args = args

  def to_scalar_expression(self) -> ScalarExpression:
    """Lowers to a binary ScalarFn applied to (lhs, *args).

    Raises:
      ValueError: If the reduction has not been bound to its lhs yet.
    """
    if self.lhs is None:
      raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
                       f"bound to its lhs: {self}")
    full_args = [self.lhs.to_scalar_expression()
                ] + [arg.to_scalar_expression() for arg in self.args]
    fn_name = None
    attr_name = None
    if self.reduce_use.binary_fn:
      fn_name = self.reduce_use.binary_fn.fn_name
    if self.reduce_use.binary_attr:
      attr_name = self.reduce_use.binary_attr.operand_def.name
    return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None,
                    full_args).expr()

  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
    # Only the args are visited; the callback is invoked neither on this node
    # nor on the lhs.
    for arg in self.args:
      arg.visit_tensor_exprs(callback)

  def __repr__(self):
    return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"


class const(TensorExpression):
  """Returns the given constant floating point or integer value."""

  def __init__(self, value: Any):
    """Stores `value` as the string form of an MLIR attribute.

    Floats become f64 FloatAttrs and ints become signless i64 IntegerAttrs.

    Raises:
      ValueError: If `value` is neither a float nor an int.
    """
    with _ir.Context():
      if isinstance(value, float):
        self.value = str(_ir.FloatAttr.get_f64(float(value)))
      elif isinstance(value, int):
        self.value = str(
            _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
      else:
        raise ValueError(f"const requires int or float but got {type(value)}")

  def to_scalar_expression(self) -> ScalarExpression:
    return ScalarConst(self.value).expr()

  def __repr__(self):
    return f"const({self.value})"


class index(TensorExpression):
  """Returns the iteration index for a given dimension name.

  Resolves the given dimension name to obtain its position in the iteration
  domain of the operation.
  """

  def __init__(self, dim: DimDef):
    self.dim_def = dim
    # -1 marks "not yet resolved"; see resolve_dimension_name.
    self.dim = -1

  def resolve_dimension_name(self, affine_state: AffineBuildState):
    """Looks up the position of the dimension in `affine_state`."""
    self.dim = affine_state.get_dim(self.dim_def.dimname)

  def to_scalar_expression(self) -> ScalarExpression:
    assert self.dim != -1, "Dimension name not resolved"
    return ScalarIndex(self.dim).expr()

  def __repr__(self):
    return f"index({repr(self.dim)})"


###############################################################################
# Function types and function definitions.
###############################################################################


class FunctionKind(Enum):
  """Category of a function applied by a TensorFn."""
  UNARY = 0
  BINARY = 1
  TYPE = 2


class UnaryFnType:
  """Unary function.

  A unary function takes one tensor expression and returns the
  function evaluation result.
  """

  def __init__(self, fn_name: str):
    self.fn_name = fn_name

  def __call__(self, arg: TensorExpression) -> "TensorFn":
    return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])

  def __repr__(self):
    return f"{self.fn_name}"


class UnaryFn:
  """Unary function namespace."""
  exp = UnaryFnType("exp")
  log = UnaryFnType("log")
  abs = UnaryFnType("abs")
  ceil = UnaryFnType("ceil")
  floor = UnaryFnType("floor")
  negf = UnaryFnType("negf")


class BinaryFnType:
  """Binary function.

  A binary function takes two tensor expressions and returns the
  function evaluation result.
  """

  def __init__(self, fn_name: str):
    self.fn_name = fn_name

  def __call__(self, arg0: TensorExpression,
               arg1: TensorExpression) -> "TensorFn":
    return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])

  def __repr__(self):
    return f"{self.fn_name}"


class BinaryFn:
  """Binary function namespace.

  As the integer types are signless, signedness is implemented by different
  functions that treat integers as signed or unsigned values.

  Examples:
  - max_signed -> `arith.MaxSIOp`
  - max_unsigned -> `arith.MaxUIOp`
  """
  add = BinaryFnType("add")
  sub = BinaryFnType("sub")
  mul = BinaryFnType("mul")
  max_signed = BinaryFnType("max_signed")
  min_signed = BinaryFnType("min_signed")
  max_unsigned = BinaryFnType("max_unsigned")
  min_unsigned = BinaryFnType("min_unsigned")


class TypeFnType:
  """Type conversion function.

  A type conversion function takes a target type and a tensor expression and
  returns the casted tensor expression.
  """

  def __init__(self, fn_name: str):
    self.fn_name = fn_name

  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
    return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])

  def __repr__(self):
    return f"{self.fn_name}"


class TypeFn:
  """Type conversion function namespace.

  As the integer types are signless, signedness is implemented by different
  cast functions that treat integers as signed (`cast_signed`) or unsigned
  (`cast_unsigned`) values.

  Examples:
  - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
  - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
  """
  cast_signed = TypeFnType("cast_signed")
  cast_unsigned = TypeFnType("cast_unsigned")


class ReduceFnUse:
  """Reduction function use.

  A reduction use specifies the reduction function and dimensions.
  """

  def __init__(self, binary_fn: Optional[BinaryFnType],
               binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef):
    """Creates a reduction use.

    Args:
      binary_fn: A fixed binary reduction function, or None.
      binary_attr: A binary function attribute, or None.
      *reduce_dims: The dimensions reduced over.

    Raises:
      ValueError: If not exactly one of `binary_fn` and `binary_attr` is set.
    """
    if bool(binary_fn) + bool(binary_attr) != 1:
      raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
    self.binary_fn = binary_fn
    self.binary_attr = binary_attr
    self.reduce_dims = reduce_dims

  def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
    return TensorReduceFn(self, args)

  def __repr__(self):
    fn = self.binary_fn if self.binary_fn else self.binary_attr
    return (
        f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})")


class ReduceFnType:
  """Reduction function.

  A binary function that reduces its RHS into its LHS.
  """

  def __init__(self, binary_fn: BinaryFnType):
    if not isinstance(binary_fn, BinaryFnType):
      raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
    self.binary_fn = binary_fn

  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
    """Binds the reduction dims, e.g. `ReduceFn.add[D.k]`."""
    return ReduceFnUse(self.binary_fn, None, *reduce_dims)

  def __repr__(self):
    return f"reduce_{repr(self.binary_fn)}"


class ReduceFn:
  """Reduction function namespace."""
  add = ReduceFnType(BinaryFn.add)
  mul = ReduceFnType(BinaryFn.mul)
  max_signed = ReduceFnType(BinaryFn.max_signed)
  min_signed = ReduceFnType(BinaryFn.min_signed)
  max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
  min_unsigned = ReduceFnType(BinaryFn.min_unsigned)


###############################################################################
# Operand definitions.
###############################################################################


class OperandKind(Enum):
  """Kind of an operand registered with a LinalgOpDef."""
  INPUT_TENSOR = 0
  SCALAR = 1
  OUTPUT_TENSOR = 2
  INDEX_ATTR = 3
  UNARY_FN_ATTR = 4
  BINARY_FN_ATTR = 5
  TYPE_FN_ATTR = 6


class OperandDef:
  """Definition of an operand passed to an operation.

  Keep the meta information of Tensor, Scalar, and Attribute operands and
  provide the shared registration functionality.
  """

  def __init__(self,
               kind: OperandKind,
               type_var: Optional[TypeVar] = None,
               size_exprs: Optional[Sequence[AffineExprDef]] = None,
               index_dims: Optional[Sequence[DimDef]] = None,
               default_indices: Optional[Sequence[int]] = None,
               default_fn: Optional[str] = None):
    """Creates an unregistered operand definition.

    Raises:
      ValueError: If `type_var` is given but is not a TypeVar.
    """
    if type_var and not isinstance(type_var, TypeVar):
      raise ValueError(
          f"OperandDef requires a TypeVar but got {repr(type_var)}")
    self.owner = None  # type: Optional["LinalgOpDef"]
    self.type_var = type_var
    self.size_exprs = size_exprs
    self.index_dims = index_dims
    self.default_indices = default_indices
    self.default_fn = default_fn
    self.kind = kind
    # Name and position are assigned by attach().
    self.name = None  # type: Optional[str]
    self.registered_index = -1  # type: int

  def attach(self, index: int, name: str, owner: "LinalgOpDef"):
    """Registers this operand with `owner` at position `index` as `name`.

    Raises:
      ValueError: If the operand is already registered with an op.
    """
    if self.owner:
      raise ValueError(f"OperandDef already registered with an op: {self}")
    self.registered_index = index
    self.name = name
    self.owner = owner

  def is_input(self) -> bool:
    return (self.kind == OperandKind.SCALAR or
            self.kind == OperandKind.INPUT_TENSOR)

  def is_tensor(self) -> bool:
    return (self.kind == OperandKind.INPUT_TENSOR or
            self.kind == OperandKind.OUTPUT_TENSOR)

  def is_attribute(self) -> bool:
    return (self.kind == OperandKind.INDEX_ATTR or
            self.kind == OperandKind.UNARY_FN_ATTR or
            self.kind == OperandKind.BINARY_FN_ATTR or
            self.kind == OperandKind.TYPE_FN_ATTR)

  def __hash__(self):
    return hash(id(self))

  def __repr__(self):
    return (f"{self.name}:OperandDef(kind={self.kind.name}, "
            f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
            f"index_dims={self.index_dims}, "
            f"default_indices={self.default_indices}, "
            f"default_fn={self.default_fn})")


class TensorDef:
  """Tensor operand definition.

  Tensor operands are indexed using the associated indexing_map when forwarded
  to the body of the structured op. A unique name identifies the tensor operands
  and an index determines their position in the operation's parameter list. A
  tensor definition takes a type, a shape, and an optional flag to mark output
  tensors. Additionally, a tuple of index dimensions may be used to map the
  tensor to the loop dimensions of the operation. This mapping is needed to
  compute the indexing map of shape-only tensors that have no uses.
  """

  def __init__(self,
               type_var: TypeVar,
               *shape: AffineExprDef,
               index_dims: Optional[Sequence[DimDef]] = None,
               output: bool = False):
    """Creates a tensor operand definition.

    Raises:
      ValueError: If `index_dims` does not match the shape rank or contains a
        non-DimDef entry.
    """
    if index_dims and len(shape) != len(index_dims):
      raise ValueError(f"Expected the shape rank {len(shape)} to match the "
                       f"number of index_dims {len(index_dims)}")
    if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
      raise ValueError(f"TensorDef requires index dims of type DimDef but "
                       f"got {index_dims}")
    kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
    self.operand_def = OperandDef(
        kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)

  def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
    """Creates a TensorUse of this tensor subscripted by `dims`.

    A single subscript is wrapped into a 1-tuple and `[None]` denotes a 0-d
    scalar use.

    Raises:
      KeyError: If any subscript is not an AffineExprDef.
    """
    assert self.operand_def.owner, "TensorDef is not registered with an op"
    # NOTE(review): `state` is not used below.
    state = AffineBuildState(
        global_state=self.operand_def.owner._affine_state,
        allow_new_symbols=False)
    if not isinstance(dims, tuple):
      dims = (dims,)  # Handle single subscript case.
    # Special case: (None) is a 0d-scalar use.
    if dims == (None,):
      dims = ()

    exprs = []
    for expr_def in dims:
      if not isinstance(expr_def, AffineExprDef):
        raise KeyError(
            "A TensorDef can only be subscripted by a tuple of affine dims")
      exprs.append(expr_def)
    return TensorUse(self.operand_def, exprs)

  def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
    """Creates a new 1:1 comprehension by binding this tensor to an expression.

    Note that due to the way assignment works in Python, we have to capture
    direct assignment as a setitem on the TensorDef.

    Raises:
      ValueError: If `value` is not a TensorExpression.
    """
    if not isinstance(value, TensorExpression):
      raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. "
                       f"Got: {repr(value)}")
    use = self[dims]
    comp = Comprehension((use, value))
    self.operand_def.owner.comprehensions.append(comp)


class ScalarDef(TensorExpression):
  """Scalar operand definition.

  Scalar operands are forwarded to the body of the structured op as they are.
  A unique name identifies the scalars and an index determines their position in
  the operation's parameter list.
  """

  def __init__(self, type_var: TypeVar):
    self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)

  @property
  def scalar_name(self) -> str:
    """Name the operand was registered under (asserts it is registered)."""
    name = self.operand_def.name
    assert name is not None, "ScalarDef not registered with an op"
    return name

  def to_scalar_expression(self) -> ScalarExpression:
    return ScalarArg(self.scalar_name).expr()


class IndexAttrDef:
  """Index attribute definition.

  Index attributes provide a way to define and set symbols that can be used in
  indexing expressions. Every attribute specifies a tuple of symbols that at
  compile-time are replaced by integer values as well as their default values.
  """

  def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
    """Creates an index attribute with one int default per symbol.

    Raises:
      ValueError: If a size is not a SymbolDef, a default is not an int, or
        the number of defaults differs from the number of sizes.
    """
    if any(not isinstance(size, SymbolDef) for size in sizes):
      raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef "
                       f"but got {sizes}")
    if any(not isinstance(default_val, int) for default_val in default):
      raise ValueError(f"IndexAttrDef requires default values of type int "
                       f"but got {default}")
    if len(sizes) != len(default):
      raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
                       f"but got {len(default)}")
    self.operand_def = OperandDef(
        OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)


class UnaryFnAttrDef:
  """Unary function attribute definition.

  Unary function attributes provide a way to make the arithmetic computation
  parametrizable. Every attribute specifies a default unary function
  that may be overwritten at operation instantiation time.
  """

  def __init__(self, default: "UnaryFnType"):
    if not isinstance(default, UnaryFnType):
      raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType "
                       f"but got {default}")
    self.operand_def = OperandDef(
        OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name)

  def __call__(self, arg: TensorExpression) -> TensorFn:
    return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])


class BinaryFnAttrDef:
  """Binary function attribute definition.

  Binary function attributes provide a way to make the arithmetic computation
  parametrizable. Every attribute specifies a default binary function
  that may be overwritten at operation instantiation time.
  """

  def __init__(self, default: "BinaryFnType"):
    if not isinstance(default, BinaryFnType):
      raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType "
                       f"but got {default}")
    self.operand_def = OperandDef(
        OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name)

  def __call__(self, arg0: TensorExpression,
               arg1: TensorExpression) -> TensorFn:
    return TensorFn(FunctionKind.BINARY, None, self.operand_def, None,
                    [arg0, arg1])

  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
    """Uses this attribute as a reduction function over `reduce_dims`."""
    return ReduceFnUse(None, self, *reduce_dims)


class TypeFnAttrDef:
  """Type conversion function attribute definition.

  Type conversion function attributes provide a way to make type conversions
  parameterizable. Every attribute specifies a default type conversion function
  that may be overwritten at operation instantiation time.
  """

  def __init__(self, default: "TypeFnType"):
    if not isinstance(default, TypeFnType):
      raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType "
                       f"but got {default}")
    self.operand_def = OperandDef(
        OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)

  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
    return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])


###############################################################################
# Operation definition.
###############################################################################


class Comprehension:
  """Represents a single comprehension."""

  def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
    """Creates a comprehension from (definition, value) pairs.

    Raises:
      ValueError: If a reduction value has already been bound to an lhs.
    """
    self.definitions = list()  # List[TensorUse]
    self.values = list()  # List[TensorExpression]

    # Bind each reduction rhs to the lhs it reduces into.
    for assign, value in bindings:
      if isinstance(value, TensorReduceFn):
        if value.lhs:
          raise ValueError(f"Reduction expression already assigns: {value}")
        value.lhs = assign
      self.definitions.append(assign)
      self.values.append(value)

  @property
  def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
    """Gets the set of reduction dim tuples, one per value.

    Non-reduction values contribute the empty tuple.
    """
    result = set()
    for use in self.values:
      if isinstance(use, TensorReduceFn):
        result.add(use.reduce_use.reduce_dims)
      else:
        result.add(tuple())
    return result

  def __repr__(self):
    # Use the parenthesized form for anything other than exactly one binding,
    # so an empty comprehension does not raise IndexError.
    if len(self.definitions) != 1:
      defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
      values_repr = f"({', '.join(repr(v) for v in self.values)})"
    else:
      defs_repr = f"{repr(self.definitions[0])}"
      values_repr = f"{repr(self.values[0])}"

    return f"{defs_repr} = {values_repr}"


class OpInterfaceDef:
  """An interface that an op implements."""

  def __init__(self, cpp_name: str):
    self.cpp_name = cpp_name


ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")


class OpDefinitionDef:
  """A method that an op implements."""

  def __init__(self, def_name: str):
    self.def_name = def_name


Canonicalizer = OpDefinitionDef("hasCanonicalizer")


class OpMetadataDef(YAMLObject):
  """Metadata about the op (generally not behavior impacting)."""
  yaml_tag = "!LinalgOpMetadata"

  def __init__(self, name: str, cpp_class_name: Optional[str],
               doc: Optional[str]):
    self.name = name
    # Fall back to the op name when no explicit C++ class name is given.
    self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
    self.doc = doc
    self.implements = []  # type: List[OpInterfaceDef]
    self.defines = []  # type: List[OpDefinitionDef]

  def to_yaml_custom_dict(self):
    """Returns the metadata as a dict; empty implements/defines are omitted."""
    d = dict(
        name=self.name,
        cpp_class_name=self.cpp_class_name,
        doc=self.doc,
    )
    if self.implements:
      d["implements"] = [intr.cpp_name for intr in self.implements]
    if self.defines:
      d["defines"] = [defi.def_name for defi in self.defines]
    return d


class LinalgOpDef:
  """Definition of a linalg op."""

  def __init__(self,
               name: str,
               cpp_class_name: Optional[str] = None,
               doc: Optional[str] = None):
    self.metadata = OpMetadataDef(
        name=name, cpp_class_name=cpp_class_name, doc=doc)
    self.registered_operands = dict()  # type: Dict[str, OperandDef]
    self.domain = list()  # type: List[DimDef]
    self.comprehensions = list()  # type: List[Comprehension]
    self._affine_state = AffineBuildState()

  def add_operand(self, name: str, operand: OperandDef):
    """Registers an operand.

    Raises:
      ValueError: If `name` is already registered, an attribute name clashes
        with a structured op method, or the operand is registered out of
        order (inputs, then outputs, then attributes).
    """
    if name in self.registered_operands:
      # Report the operand already registered under `name` (previously this
      # looked up the literal key 'name' and raised KeyError instead).
      raise ValueError(f"The operand {name} is already registered "
                       f"to {self.registered_operands[name]}")
    structured_op_methods = [
        "inputs", "outputs", "result_tensors", "region", "iterator_types",
        "indexing_maps", "getRegionBuilder", "getLibraryCallName"
    ]
    if operand.is_attribute() and name in structured_op_methods:
      raise ValueError(f"The attribute name {name} conflicts with a structured "
                       f"op method name")
    # Ensure output tensors are registered after input tensors and scalars and
    # attributes are registered after all other operand types.
    if operand.is_input() and any(
        not op_def.is_input() for op_def in self.registered_operands.values()):
      raise ValueError(f"Input {name} registered after an output or attribute")
    if operand.kind == OperandKind.OUTPUT_TENSOR and any(
        op_def.is_attribute() for op_def in self.registered_operands.values()):
      raise ValueError(f"Output {name} registered after an attribute")
    operand.attach(len(self.registered_operands), name, self)
    self.registered_operands[name] = operand

  def __repr__(self):
    lines = [
        f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"
    ]
    for operand in self.registered_operands.values():
      lines.append(f" {operand}")
    if self.comprehensions:
      lines[-1] += " {"
      for comprehension in self.comprehensions:
        lines.append(f" {comprehension}")
      lines.append("}")
    return "\n".join(lines)