1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4"""Model classes representing a tensor comprehension.
5
6These classes model the language more at an AST level as evaluated. Reasoning
7about it typically involves processing this form into config objects that
8represent actual op definitions (i.e. YAML).
9"""
10
11from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
12from enum import Enum
13
14from ..... import ir as _ir
15from .affine import *
16from .scalar_expr import *
17from .types import *
18from .yaml_helper import *
19
20###############################################################################
21# Tensor expression nodes.
22###############################################################################
23
24
class TensorExpression:
  """An expression that can appear on the RHS of a comprehension.

  Subclasses override `to_scalar_expression` to lower themselves to the scalar
  expression form, and override `visit_tensor_exprs` when they have
  sub-expressions. The `collect_*` helpers walk the expression via
  `visit_tensor_exprs` and add matching nodes to a caller-provided set.
  """

  def to_scalar_expression(self) -> ScalarExpression:
    """Lowers this expression to a ScalarExpression; must be overridden."""
    raise NotImplementedError()

  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
    """Visits all tensor expressions reachable by the expression.

    The base implementation only passes `self` to `callback`; subclasses with
    operands extend it to recurse into them.
    """
    callback(self)

  def collect_dim_uses(self, uses: Set["DimDef"]):
    """Collects all DimDefs reachable through this expression into `uses`."""

    def visit_dim_def(dim_def: AffineExprDef):
      if isinstance(dim_def, DimDef):
        uses.add(dim_def)

    def visit_affine_exprs(expr: "TensorExpression"):
      if isinstance(expr, TensorUse):
        for ind in expr.indices:
          ind.visit_affine_exprs(visit_dim_def)
      if isinstance(expr, TensorReduceFn):
        # The reduction dims are stored on the ReduceFnUse held in
        # `reduce_use` (there is no `reduce_fn` attribute). NOTE(review): as
        # written, TensorReduceFn.visit_tensor_exprs only visits its args, so
        # this branch is reached only if a TensorReduceFn is itself passed to
        # the callback.
        for ind in expr.reduce_use.reduce_dims:
          ind.visit_affine_exprs(visit_dim_def)

    self.visit_tensor_exprs(visit_affine_exprs)

  def collect_tensor_uses(self, uses: Set["TensorUse"]):
    """Collects all TensorUses reachable through this expression into `uses`."""

    def visit_tensor_use(expr: "TensorExpression"):
      if isinstance(expr, TensorUse):
        uses.add(expr)

    self.visit_tensor_exprs(visit_tensor_use)

  def collect_indices(self, indices: Set["index"]):
    """Collects all index accesses reachable through this expression."""

    def visit_index(expr: "TensorExpression"):
      if isinstance(expr, index):
        indices.add(expr)

    self.visit_tensor_exprs(visit_index)

  def collect_scalar_uses(self, uses: Set["ScalarDef"]):
    """Collects all ScalarDefs reachable through this expression into `uses`."""

    def visit_scalar_def(expr: "TensorExpression"):
      if isinstance(expr, ScalarDef):
        uses.add(expr)

    self.visit_tensor_exprs(visit_scalar_def)

  def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
    """Builds `BinaryFn.add(self, rhs)`."""
    return BinaryFn.add(self, rhs)

  def __mul__(self, rhs) -> "TensorExpression":
    """Builds `BinaryFn.mul(self, rhs)`."""
    return BinaryFn.mul(self, rhs)

  def __sub__(self, rhs) -> "TensorExpression":
    """Builds `BinaryFn.sub(self, rhs)`."""
    return BinaryFn.sub(self, rhs)

  def __hash__(self):
    # Identity-based hash so expressions can be collected in sets; equality is
    # the default identity comparison.
    return hash(id(self))
90
91
class TensorUse(TensorExpression):
  """A used tensor represented by its (tensor_name, indices).

  Note that forming a comprehension via direct assignment is performed through
  __setitem__ on the TensorDef level. However, performing a reduction with the
  compound `+=` op (the only compound op implemented here) is done by doing a:
    TensorDef.__getitem__
    TensorUse.__iadd__
    TensorDef.__setitem__
  """

  def __init__(self, operand_def: "OperandDef",
               indices: Sequence[AffineExprDef]):
    self.operand_def = operand_def
    self.indices = tuple(indices)

  def to_scalar_expression(self) -> ScalarExpression:
    """Lowers the use to a scalar argument named after the tensor."""
    return ScalarArg(self.tensor_name).expr()

  @property
  def tensor_name(self) -> str:
    """Name of the used tensor; only set once registered with an op."""
    name = self.operand_def.name
    assert name is not None, "TensorDef not registered with an op"
    return name

  def _compute_reduce_dims(self, rhs: TensorExpression) -> Tuple[DimDef, ...]:
    """Computes the reduction dims for an implicit reduction into `self`.

    Assumes `rhs` is the expression being reduced and `self` is the tensor
    being reduced into. Every dim indexing a TensorUse on the rhs that does not
    index `self` is a reduction dim. The dims are returned in the order they
    are first encountered on the rhs, which keeps the result deterministic
    (iterating a set of dims would yield an arbitrary, hash-dependent order).
    """
    lhs_dims = set()
    self.collect_dim_uses(lhs_dims)

    seen = set()
    reduce_dims = []

    def visit_dim_def(dim_def: AffineExprDef):
      if (isinstance(dim_def, DimDef) and dim_def not in lhs_dims and
          dim_def not in seen):
        seen.add(dim_def)
        reduce_dims.append(dim_def)

    def visit_tensor_use(expr: TensorExpression):
      if isinstance(expr, TensorUse):
        for ind in expr.indices:
          ind.visit_affine_exprs(visit_dim_def)

    rhs.visit_tensor_exprs(visit_tensor_use)
    return tuple(reduce_dims)

  def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
    """Turns `lhs += rhs` into an add reduction over the implicit dims."""
    return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)

  def __repr__(self):
    return (f"{self.operand_def.name}"
            f"[{', '.join([repr(i) for i in self.indices])}]")
134
135
class TensorFn(TensorExpression):
  """Application of a tensor function.

  The function is named either directly via `name` or indirectly via the
  attribute `operand_def`; exactly one of the two must be given. `kind`
  selects the function category and `type_var` is forwarded to the scalar
  form (it is used by type conversion functions).
  """

  def __init__(self, kind: "FunctionKind", name: Optional[str],
               operand_def: Optional["OperandDef"], type_var: Optional[TypeVar],
               args: Sequence[TensorExpression]):
    has_name = bool(name)
    has_operand = bool(operand_def)
    # Reject both "neither given" and "both given".
    if has_name == has_operand:
      raise ValueError("One of 'name', 'operand_def' must be specified")
    self.name = name
    self.kind = kind
    self.operand_def = operand_def
    self.type_var = type_var
    self.args = args

  def to_scalar_expression(self) -> ScalarExpression:
    """Lowers the application and all of its arguments to a ScalarFn."""
    attr_name = None
    if self.operand_def:
      attr_name = self.operand_def.name
      assert attr_name, "TensorFn not registered with an op"
    scalar_args = [arg.to_scalar_expression() for arg in self.args]
    return ScalarFn(self.kind, self.name, attr_name, self.type_var,
                    scalar_args).expr()

  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
    """Visits this node first, then every argument subtree in order."""
    super().visit_tensor_exprs(callback)
    for sub_expr in self.args:
      sub_expr.visit_tensor_exprs(callback)

  def __repr__(self):
    fn_name = self.name if not self.operand_def else self.operand_def.name
    joined_args = ", ".join(repr(a) for a in self.args)
    return (f"{self.kind.name}.{fn_name}(type_var={self.type_var}, "
            f"args={joined_args})")
166
167
class TensorReduceFn(TensorExpression):
  """Application of a reduction function.

  This captures the lhs (initial value) separately from the rhs. The lhs is
  unknown at construction time and stays None until a Comprehension binds
  the reduction to the TensorUse it assigns.
  """

  def __init__(self, reduce_use: "ReduceFnUse",
               args: Sequence[TensorExpression]):
    self.reduce_use = reduce_use
    self.lhs = None  # type: Optional[TensorUse]
    self.args = args

  def to_scalar_expression(self) -> ScalarExpression:
    """Lowers to a binary ScalarFn whose first operand is the lhs.

    Raises:
      ValueError: if no lhs has been bound yet.
    """
    if self.lhs is None:
      raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
                       f"bound to its lhs: {self}")
    # The lhs acts as the accumulator and therefore comes first.
    full_args = [self.lhs.to_scalar_expression()]
    full_args.extend(arg.to_scalar_expression() for arg in self.args)
    use = self.reduce_use
    fn_name = use.binary_fn.fn_name if use.binary_fn else None
    attr_name = use.binary_attr.operand_def.name if use.binary_attr else None
    return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None,
                    full_args).expr()

  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
    """Visits the argument subtrees only; `self` is not passed to callback."""
    for sub_expr in self.args:
      sub_expr.visit_tensor_exprs(callback)

  def __repr__(self):
    rendered_args = ", ".join(repr(a) for a in self.args)
    return f"{self.reduce_use!r}({rendered_args})"
201
202
class const(TensorExpression):
  """Returns the given constant floating point or integer value.

  The value is stored as the textual form of an MLIR attribute: an f64
  FloatAttr for floats, or a signless 64-bit IntegerAttr for ints. Since bool
  is a subclass of int, booleans take the integer path.
  """

  def __init__(self, value: Any):
    with _ir.Context():
      if isinstance(value, float):
        attr = _ir.FloatAttr.get_f64(float(value))
      elif isinstance(value, int):
        i64_type = _ir.IntegerType.get_signless(64)
        attr = _ir.IntegerAttr.get(i64_type, int(value))
      else:
        raise ValueError(f"const requires int or float but got {type(value)}")
      # Stringify while the context is still active.
      self.value = str(attr)

  def to_scalar_expression(self) -> ScalarExpression:
    """Lowers to a ScalarConst holding the attribute text."""
    return ScalarConst(self.value).expr()

  def __repr__(self):
    return f"const({self.value})"
221
222
class index(TensorExpression):
  """Returns the iteration index for a given dimension name.

  Resolves the given dimension name to obtain its position in the iteration
  domain of the operation.
  """

  def __init__(self, dim: DimDef):
    self.dim_def = dim
    # Position in the iteration domain; -1 until resolve_dimension_name runs.
    self.dim = -1

  def resolve_dimension_name(self, affine_state: AffineBuildState):
    """Looks up the position of the dimension in `affine_state`."""
    dim_name = self.dim_def.dimname
    self.dim = affine_state.get_dim(dim_name)

  def to_scalar_expression(self) -> ScalarExpression:
    """Lowers to a ScalarIndex; the dimension must have been resolved."""
    assert self.dim != -1, "Dimension name not resolved"
    return ScalarIndex(self.dim).expr()

  def __repr__(self):
    return f"index({self.dim!r})"
243
244
245###############################################################################
246# Function types and function definitions.
247###############################################################################
248
249
class FunctionKind(Enum):
  """Category of a function applied in a tensor expression."""
  # Members are numbered 0, 1, 2 in declaration order.
  UNARY, BINARY, TYPE = range(3)
254
255
class UnaryFnType:
  """Unary function.

  Calling an instance on one tensor expression builds a TensorFn of kind
  UNARY that applies the function named `fn_name` to it.
  """

  def __init__(self, fn_name: str):
    self.fn_name = fn_name

  def __call__(self, arg: TensorExpression) -> "TensorFn":
    operands = [arg]
    return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, operands)

  def __repr__(self):
    return str(self.fn_name)
271
272
class UnaryFn:
  """Unary function namespace.

  Every attribute is a UnaryFnType whose `fn_name` equals the attribute name.
  """
  exp = UnaryFnType("exp")
  log = UnaryFnType("log")
  abs = UnaryFnType("abs")
  ceil = UnaryFnType("ceil")
  floor = UnaryFnType("floor")
  negf = UnaryFnType("negf")
281
282
class BinaryFnType:
  """Binary function.

  Calling an instance on two tensor expressions builds a TensorFn of kind
  BINARY that applies the function named `fn_name` to them.
  """

  def __init__(self, fn_name: str):
    self.fn_name = fn_name

  def __call__(self, arg0: TensorExpression,
               arg1: TensorExpression) -> "TensorFn":
    operands = [arg0, arg1]
    return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, operands)

  def __repr__(self):
    return str(self.fn_name)
299
300
class BinaryFn:
  """Binary function namespace.

  As the integer types are signless, signedness is implemented by different
  functions that treat integers as signed or unsigned values.

  Examples:
  - max_signed -> `arith.MaxSIOp`
  - max_unsigned -> `arith.MaxUIOp`
  """
  add = BinaryFnType("add")
  sub = BinaryFnType("sub")
  mul = BinaryFnType("mul")
  max_signed = BinaryFnType("max_signed")
  min_signed = BinaryFnType("min_signed")
  max_unsigned = BinaryFnType("max_unsigned")
  min_unsigned = BinaryFnType("min_unsigned")
318
319
class TypeFnType:
  """Type conversion function.

  Calling an instance with a target type variable and a tensor expression
  builds a TensorFn of kind TYPE that casts the expression.
  """

  def __init__(self, fn_name: str):
    self.fn_name = fn_name

  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
    operands = [arg]
    return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, operands)

  def __repr__(self):
    return str(self.fn_name)
335
336
class TypeFn:
  """Type conversion function namespace.

  As the integer types are signless, signedness is implemented by different
  cast functions that treat integers as signed (`cast_signed`) or unsigned
  (`cast_unsigned`) values.

  Examples:
  - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
  - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
  """
  cast_signed = TypeFnType("cast_signed")
  cast_unsigned = TypeFnType("cast_unsigned")
350
351
class ReduceFnUse:
  """Reduction function use.

  A reduction use specifies the reduction function and dimensions. The
  function is given either directly (`binary_fn`) or through a function
  attribute (`binary_attr`); exactly one of the two must be provided. Calling
  the use on tensor expressions builds a TensorReduceFn.
  """

  def __init__(self, binary_fn: Optional[BinaryFnType],
               binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef):
    # Reject both "neither given" and "both given".
    if bool(binary_fn) == bool(binary_attr):
      raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
    self.binary_fn = binary_fn
    self.binary_attr = binary_attr
    self.reduce_dims = reduce_dims

  def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
    return TensorReduceFn(self, args)

  def __repr__(self):
    fn = self.binary_attr if not self.binary_fn else self.binary_fn
    dims = ", ".join(repr(d) for d in self.reduce_dims)
    return f"reduce_{fn!r}({dims})"
373
374
class ReduceFnType:
  """Reduction function.

  A binary function that reduces its RHS into its LHS. Subscripting it with
  the reduction dims yields a ReduceFnUse.
  """

  def __init__(self, binary_fn: BinaryFnType):
    if not isinstance(binary_fn, BinaryFnType):
      raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
    self.binary_fn = binary_fn

  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
    """Returns a ReduceFnUse reducing over `reduce_dims`.

    Python passes a single subscript unwrapped, so `fn[D.k]` arrives as a bare
    DimDef while `fn[D.k, D.l]` arrives as a tuple; both are accepted.
    """
    if isinstance(reduce_dims, DimDef):
      reduce_dims = (reduce_dims,)  # Handle single subscript case.
    return ReduceFnUse(self.binary_fn, None, *reduce_dims)

  def __repr__(self):
    return f"reduce_{repr(self.binary_fn)}"
391
392
class ReduceFn:
  """Reduction function namespace.

  Every attribute is a ReduceFnType wrapping the BinaryFn of the same name.
  Subscripting one with a tuple of reduction dims yields a ReduceFnUse.
  """
  add = ReduceFnType(BinaryFn.add)
  mul = ReduceFnType(BinaryFn.mul)
  max_signed = ReduceFnType(BinaryFn.max_signed)
  min_signed = ReduceFnType(BinaryFn.min_signed)
  max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
  min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
400
401
402###############################################################################
403# Operand definitions.
404###############################################################################
405
406
class OperandKind(Enum):
  """Kind of an operand described by an OperandDef.

  OperandDef groups these kinds: INPUT_TENSOR and SCALAR are inputs,
  INPUT_TENSOR and OUTPUT_TENSOR are tensors, and the *_ATTR kinds are
  attributes.
  """
  # Tensor and scalar operands.
  INPUT_TENSOR = 0
  SCALAR = 1
  OUTPUT_TENSOR = 2
  # Attribute operands.
  INDEX_ATTR = 3
  UNARY_FN_ATTR = 4
  BINARY_FN_ATTR = 5
  TYPE_FN_ATTR = 6
415
416
class OperandDef:
  """Definition of an operand passed to an operation.

  Keeps the meta information of Tensor, Scalar, and Attribute operands and
  provides the shared registration functionality. An operand starts out
  unregistered (no owner, no name, index -1) and is bound to an op exactly
  once via `attach`.
  """

  def __init__(self,
               kind: OperandKind,
               type_var: Optional[TypeVar] = None,
               size_exprs: Optional[Sequence[AffineExprDef]] = None,
               index_dims: Optional[Sequence[DimDef]] = None,
               default_indices: Optional[Sequence[int]] = None,
               default_fn: Optional[str] = None):
    if type_var and not isinstance(type_var, TypeVar):
      raise ValueError(
          f"OperandDef requires a TypeVar but got {type_var!r}")
    self.owner = None  # type: Optional["LinalgOpDef"]
    self.type_var = type_var
    self.size_exprs = size_exprs
    self.index_dims = index_dims
    self.default_indices = default_indices
    self.default_fn = default_fn
    self.kind = kind
    self.name = None  # type: Optional[str]
    self.registered_index = -1  # type: int

  def attach(self, index: int, name: str, owner: "LinalgOpDef"):
    """Registers the operand with `owner` under `name` at position `index`.

    Raises:
      ValueError: if the operand already has an owner.
    """
    if self.owner:
      raise ValueError(f"OperandDef already registered with an op: {self}")
    self.registered_index, self.name, self.owner = index, name, owner

  def is_input(self) -> bool:
    """True for input tensors and scalars."""
    return self.kind in (OperandKind.SCALAR, OperandKind.INPUT_TENSOR)

  def is_tensor(self) -> bool:
    """True for input and output tensors."""
    return self.kind in (OperandKind.INPUT_TENSOR, OperandKind.OUTPUT_TENSOR)

  def is_attribute(self) -> bool:
    """True for index and function attributes."""
    return self.kind in (OperandKind.INDEX_ATTR, OperandKind.UNARY_FN_ATTR,
                         OperandKind.BINARY_FN_ATTR, OperandKind.TYPE_FN_ATTR)

  def __hash__(self):
    # Identity-based hash; equality is the default identity comparison.
    return hash(id(self))

  def __repr__(self):
    fields = ", ".join([
        f"kind={self.kind.name}",
        f"type={self.type_var!r}",
        f"size_exprs={self.size_exprs}",
        f"index_dims={self.index_dims}",
        f"default_indices={self.default_indices}",
        f"default_fn={self.default_fn}",
    ])
    return f"{self.name}:OperandDef({fields})"
474
475
class TensorDef:
  """Tensor operand definition.

  Tensor operands are indexed using the associated indexing_map when forwarded
  to the body of the structured op. A unique name identifies the tensor operands
  and an index determines their position in the operation's parameter list. A
  tensor definition takes type, a shape, and an optional flag to mark output
  tensors. Additionally, a tuple of index dimensions may be used to map the
  tensor to the loop dimensions of the operation. This mapping is needed to
  compute the indexing map of shape-only tensors that have no uses.
  """

  def __init__(self,
               type_var: TypeVar,
               *shape: AffineExprDef,
               index_dims: Optional[Sequence[DimDef]] = None,
               output: bool = False):
    if index_dims and len(shape) != len(index_dims):
      raise ValueError(f"Expected the shape rank {len(shape)} to match the "
                       f"number of index_dims {len(index_dims)}")
    if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
      raise ValueError(f"TensorDef requires index dims of type DimDef but "
                       f"got {index_dims}")
    kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
    self.operand_def = OperandDef(
        kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)

  def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
    """Creates a TensorUse of this tensor subscripted by `dims`.

    `t[None]` denotes a 0-d (scalar) use.

    Raises:
      KeyError: if a subscript is not an AffineExprDef.
    """
    assert self.operand_def.owner, "TensorDef is not registered with an op"
    if not isinstance(dims, tuple):
      dims = (dims,)  # Handle single subscript case.
    # Special case: (None) is a 0d-scalar use. Use an identity check rather
    # than tuple equality, which would invoke `==` on the affine exprs.
    if len(dims) == 1 and dims[0] is None:
      dims = ()

    exprs = []
    for expr_def in dims:
      if not isinstance(expr_def, AffineExprDef):
        raise KeyError(
            "A TensorDef can only be subscripted by a tuple of affine dims")
      exprs.append(expr_def)
    return TensorUse(self.operand_def, exprs)

  def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
    """Creates a new 1:1 comprehension by binding this tensor to an expression.

    Note that due to the way assignment works in Python, we have to capture
    direct assignment as a setitem on the TensorDef. The comprehension is
    appended to the owning op's `comprehensions`.

    Raises:
      ValueError: if `value` is not a TensorExpression.
    """
    if not isinstance(value, TensorExpression):
      raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. "
                       f"Got: {repr(value)}")
    use = self[dims]
    comp = Comprehension((use, value))
    self.operand_def.owner.comprehensions.append(comp)
534
535
class ScalarDef(TensorExpression):
  """Scalar operand definition.

  Scalar operands are forwarded to the body of the structured op as they are.
  A unique name identifies the scalars and an index determines their position
  in the operation's parameter list. Being a TensorExpression, a ScalarDef can
  be used directly inside comprehension expressions.
  """

  def __init__(self, type_var: TypeVar):
    self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)

  @property
  def scalar_name(self) -> str:
    """Name of the scalar; only set once registered with an op."""
    registered_name = self.operand_def.name
    assert registered_name is not None, "ScalarDef not registered with an op"
    return registered_name

  def to_scalar_expression(self) -> ScalarExpression:
    """Lowers to a scalar argument named after the operand."""
    return ScalarArg(self.scalar_name).expr()
555
556
class IndexAttrDef:
  """Index attribute definition.

  Index attributes provide a way to define and set symbols that can be used in
  indexing expressions. Every attribute specifies a tuple of symbols that at
  compile-time are replaced by integer values as well as their default values.

  Raises:
    ValueError: if a size is not a SymbolDef, a default is not an int, or the
      number of defaults differs from the number of sizes.
  """

  def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
    if not all(isinstance(size, SymbolDef) for size in sizes):
      raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef "
                       f"but got {sizes}")
    if not all(isinstance(value, int) for value in default):
      raise ValueError(f"IndexAttrDef requires default values of type int "
                       f"but got {default}")
    if len(sizes) != len(default):
      raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
                       f"but got {len(default)}")
    self.operand_def = OperandDef(
        OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
577
578
class UnaryFnAttrDef:
  """Unary function attribute definition.

  Unary function attributes provide a way to make the arithmetic computation
  parametrizable. Every attribute specifies a default unary function
  that may be overwritten at operation instantiation time.
  """

  def __init__(self, default: "UnaryFnType"):
    if not isinstance(default, UnaryFnType):
      raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType "
                       f"but got {default}")
    default_name = default.fn_name
    self.operand_def = OperandDef(
        OperandKind.UNARY_FN_ATTR, default_fn=default_name)

  def __call__(self, arg: TensorExpression) -> TensorFn:
    """Applies the attribute-selected unary function to `arg`."""
    operands = [arg]
    return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, operands)
596
597
class BinaryFnAttrDef:
  """Binary function attribute definition.

  Binary function attributes provide a way to make the arithmetic computation
  parametrizable. Every attribute specifies a default binary function
  that may be overwritten at operation instantiation time.
  """

  def __init__(self, default: "BinaryFnType"):
    if not isinstance(default, BinaryFnType):
      raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType "
                       f"but got {default}")
    self.operand_def = OperandDef(
        OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name)

  def __call__(self, arg0: TensorExpression,
               arg1: TensorExpression) -> TensorFn:
    """Applies the attribute-selected binary function to `arg0` and `arg1`."""
    return TensorFn(FunctionKind.BINARY, None, self.operand_def, None,
                    [arg0, arg1])

  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
    """Returns a reduction use of the attribute-selected function.

    Python passes a single subscript unwrapped, so `attr[D.k]` arrives as a
    bare DimDef while `attr[D.k, D.l]` arrives as a tuple; both are accepted.
    """
    if isinstance(reduce_dims, DimDef):
      reduce_dims = (reduce_dims,)  # Handle single subscript case.
    return ReduceFnUse(None, self, *reduce_dims)
620
621
class TypeFnAttrDef:
  """Type conversion function attribute definition.

  Type conversion function attributes provide a way to make type conversions
  parameterizable. Every attribute specifies a default type conversion function
  that may be overwritten at operation instantiation time.
  """

  def __init__(self, default: "TypeFnType"):
    if not isinstance(default, TypeFnType):
      raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType "
                       f"but got {default}")
    default_name = default.fn_name
    self.operand_def = OperandDef(
        OperandKind.TYPE_FN_ATTR, default_fn=default_name)

  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
    """Casts `arg` to `type_var` with the attribute-selected function."""
    operands = [arg]
    return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var,
                    operands)
639
640
641###############################################################################
642# Operation definition.
643###############################################################################
644
645
class Comprehension:
  """Represents a single comprehension.

  A comprehension holds one or more (definition, value) bindings: each
  definition is the TensorUse being written and each value is the
  TensorExpression computing it. Reduction values (TensorReduceFn) get their
  `lhs` bound to the corresponding definition here.
  """

  def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
    self.definitions = list()  # List[TensorUse]
    self.values = list()  # List[TensorExpression]

    # Find the lhs to reduction rhs.
    for assign, value in bindings:
      if isinstance(value, TensorReduceFn):
        # A reduction can only accumulate into a single lhs.
        if value.lhs is not None:
          raise ValueError(f"Reduction expression already assigns: {value}")
        value.lhs = assign
      self.definitions.append(assign)
      self.values.append(value)

  @property
  def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
    """Gets the set of distinct reduction-dim tuples across all values.

    A reduction value contributes its reduce dims; any other value
    contributes an empty tuple. Never returns None.
    """
    result = set()
    for use in self.values:
      if isinstance(use, TensorReduceFn):
        result.add(use.reduce_use.reduce_dims)
      else:
        result.add(tuple())
    return result

  def __repr__(self):
    if len(self.definitions) == 1:
      defs_repr = f"{repr(self.definitions[0])}"
      values_repr = f"{repr(self.values[0])}"
    else:
      # Zero or several bindings are rendered as parenthesized lists.
      defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
      values_repr = f"({', '.join(repr(v) for v in self.values)})"

    return f"{defs_repr} = {values_repr}"
682
683
class OpInterfaceDef:
  """An interface that an op implements.

  Identified by the name of its C++ interface (`cpp_name`), which is what the
  op metadata lists under `implements`.
  """

  def __init__(self, cpp_name: str):
    self.cpp_name = cpp_name


# Interfaces that op definitions can list in their metadata `implements`.
ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
694
695
class OpDefinitionDef:
  """A method that an op implements.

  Identified by `def_name`, which is what the op metadata lists under
  `defines`.
  """

  def __init__(self, def_name: str):
    self.def_name = def_name


# Definitions that op definitions can list in their metadata `defines`.
Canonicalizer = OpDefinitionDef("hasCanonicalizer")
704
705
class OpMetadataDef(YAMLObject):
  """Metadata about the op (generally not behavior impacting)."""
  yaml_tag = "!LinalgOpMetadata"

  def __init__(self, name: str, cpp_class_name: Optional[str],
               doc: Optional[str]):
    self.name = name
    # The C++ class name defaults to the op name.
    self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
    self.doc = doc
    self.implements = []  # type: List[OpInterfaceDef]
    self.defines = []  # type: List[OpDefinitionDef]

  def to_yaml_custom_dict(self):
    """Returns the dict representation of the metadata.

    Presumably consumed by the YAMLObject base from yaml_helper when emitting
    YAML (TODO confirm). `implements` and `defines` are only included when
    non-empty.
    """
    d = dict(
        name=self.name,
        cpp_class_name=self.cpp_class_name,
        doc=self.doc,
    )
    if self.implements:
      d["implements"] = [intr.cpp_name for intr in self.implements]
    if self.defines:
      d["defines"] = [defi.def_name for defi in self.defines]
    return d
729
730
class LinalgOpDef:
  """Definition of a linalg op.

  Collects the op metadata, the registered operands (keyed by name, in
  registration order), the iteration domain, and the comprehensions that
  form the op body.
  """

  def __init__(self,
               name: str,
               cpp_class_name: Optional[str] = None,
               doc: Optional[str] = None):
    self.metadata = OpMetadataDef(
        name=name, cpp_class_name=cpp_class_name, doc=doc)
    self.registered_operands = dict()  # type: Dict[str, OperandDef]
    self.domain = list()  # type: List[DimDef]
    self.comprehensions = list()  # type: List[Comprehension]
    self._affine_state = AffineBuildState()

  def add_operand(self, name: str, operand: OperandDef):
    """Registers an operand.

    Operands must be registered in the order: inputs (input tensors and
    scalars), then output tensors, then attributes.

    Raises:
      ValueError: if `name` is already registered, an attribute name clashes
        with a structured op method name, the registration order is violated,
        or the operand is already attached to an op.
    """
    if name in self.registered_operands:
      # Report the operand already registered under this name.
      raise ValueError(f"The operand {name} is already registered "
                       f"to {self.registered_operands[name]}")
    structured_op_methods = {
        "inputs", "outputs", "result_tensors", "region", "iterator_types",
        "indexing_maps", "getRegionBuilder", "getLibraryCallName"
    }
    if operand.is_attribute() and name in structured_op_methods:
      raise ValueError(f"The attribute name {name} conflicts with a structured "
                       "op method name")
    # Ensure output tensors are registered after input tensors and scalars and
    # attributes are registered after all other operand types.
    registered = self.registered_operands.values()
    if operand.is_input() and any(
        not op_def.is_input() for op_def in registered):
      raise ValueError(f"Input {name} registered after an output or attribute")
    if operand.kind == OperandKind.OUTPUT_TENSOR and any(
        op_def.is_attribute() for op_def in registered):
      raise ValueError(f"Output {name} registered after an attribute")
    operand.attach(len(self.registered_operands), name, self)
    self.registered_operands[name] = operand

  def __repr__(self):
    lines = [
        f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"
    ]
    for operand in self.registered_operands.values():
      lines.append(f"  {operand}")
    if self.comprehensions:
      lines[-1] += " {"
      for comprehension in self.comprehensions:
        lines.append(f"    {comprehension}")
      lines.append("}")
    return "\n".join(lines)
780