1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5"""Experimental MLIR-PyTACO with sparse tensor support.
6
7See http://tensor-compiler.org/ for TACO tensor compiler.
8
9This module implements the Python classes for PyTACO index notation. These
10include classes for data types, tensor dimension formats (aka mode formats),
11tensor dimension orderings (aka mode ordering), tensor storage formats, and
12tensors.
13
The PyTACO API doesn't follow the naming convention required by the style guide
15for this module. As such, we first implement the supporting classes and routines
16following the style guide, and then define the type aliases and constants to
17support the PyTACO API in the pytaco_api module.
18"""
19
20from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
21
22import abc
23import ctypes
24import dataclasses
25import enum
26import numpy as np
27import functools
28import operator
29import os
30import threading
31
32# Import MLIR related modules.
33from mlir import execution_engine
34from mlir import ir
35from mlir import runtime
36from mlir.dialects import arith
37from mlir.dialects import bufferization
38from mlir.dialects import builtin
39from mlir.dialects import func
40from mlir.dialects import linalg
41from mlir.dialects import sparse_tensor
42from mlir.dialects.linalg.opdsl import lang
43
44from . import mlir_pytaco_utils as utils
45
# TACO naming prefixes.
# The name of each IndexVar is this prefix followed by a unique integer id.
_TACO_INDEX_PREFIX = "i"
# NOTE(review): presumably the prefix for generated tensor names -- it is not
# used in this part of the file, confirm against the Tensor class.
_TACO_TENSOR_PREFIX = "A"

# Bitwidths for pointers and indices. Both values are passed to
# sparse_tensor.EncodingAttr.get when building the tensor encoding attribute.
_POINTER_BIT_WIDTH = 0
_INDEX_BIT_WIDTH = 0
# The entry point to the JIT compiled program. This is the name given to the
# MLIR function emitted for a tensor assignment.
_ENTRY_NAME = "main"

# Type aliases for type annotation.
# A callable implementing a unary operator, such as operator.neg.
_UnaryOp = Callable[[Any], Any]
# A callable implementing a binary operator, such as operator.add.
_BinaryOp = Callable[[Any, Any], Any]
# A callable applied to each expression node by the IndexExpr._visit methods.
_ExprVisitor = Callable[..., None]
# A dictionary for looking up the code generation information of expressions.
_ExprInfoDict = Dict["IndexExpr", "_ExprInfo"]
# A callable combining two booleans, such as operator.and_ or operator.or_.
_LogicalOp = Callable[[bool, bool], bool]
# A callable deriving a ModeFormat from two source ModeFormat values.
_ModeFormatOp = Callable[["ModeFormat", "ModeFormat"], "ModeFormat"]
# An optional predicate identifying the nodes to treat as leaves in a visit.
_SubtreeLeafChecker = Optional[Callable[..., bool]]
64
65
class Type(enum.Enum):
  """The data types supported by TACO.

  We use numpy data types to implement the enum data types. The value of each
  enum member is the numpy scalar type of the same name.
  """
  # Integral types.
  INT8 = np.int8
  INT16 = np.int16
  INT32 = np.int32
  INT64 = np.int64
  # Floating point types.
  FLOAT16 = np.float16
  FLOAT32 = np.float32
  FLOAT64 = np.float64
  # Complex types.
  COMPLEX64 = np.complex64
  COMPLEX128 = np.complex128
80
81
# All floating point type enums, consulted by DType.is_float.
_FLOAT_TYPES = (Type.FLOAT16, Type.FLOAT32, Type.FLOAT64)
# All integral type enums, consulted by DType.is_int.
_INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64)
# All complex type enums, consulted by DType.is_complex.
_COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128)
# Type alias for any numpy type used to implement the runtime support for the
# enum data types. It lists the same numpy types as the values of Type.
_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float16,
                        np.float32, np.float64, np.complex64, np.complex128]
92
93
@dataclasses.dataclass(frozen=True)
class DType:
  """An immutable data type description wrapping a Type enum value.

  The TACO API dtype class is supported with an alias of this class.

  Methods defined by the TACO API:
    is_float: Tells whether the data type is a floating point type.
    is_int:   Tells whether the data type is an integral type.

  Attributes:
    kind: The Type enum value describing the data type, Type.FLOAT32 when not
      provided.
    value: The numpy data type that implements the TACO data type.
  """
  kind: Type = Type.FLOAT32

  @property
  def value(self) -> _AnyRuntimeType:
    """The numpy data type stored as the value of the Type enum member."""
    return self.kind.value

  def is_float(self) -> bool:
    """Tells whether kind is one of the floating point types."""
    return self.kind in _FLOAT_TYPES

  def is_int(self) -> bool:
    """Tells whether kind is one of the integral types."""
    return self.kind in _INT_TYPES

  def is_complex(self) -> bool:
    """Tells whether kind is one of the complex types."""
    return self.kind in _COMPLEX_TYPES
126
127
def _dtype_to_mlir_str(dtype: DType) -> str:
  """Returns the MLIR string for the given dtype.

  Args:
    dtype: The DType whose kind selects the MLIR type string.

  Returns:
    The MLIR type string, such as "i8", "f32" or "complex<f64>".

  Raises:
    KeyError: If the kind of dtype has no entry in the lookup table.
  """
  dtype_to_str = {
      # Each Type member needs its own key. A duplicated key silently keeps
      # only the last value and leaves the other member without an entry.
      Type.INT8: "i8",
      Type.INT16: "i16",
      Type.INT32: "i32",
      Type.INT64: "i64",
      Type.FLOAT16: "f16",
      Type.FLOAT32: "f32",
      Type.FLOAT64: "f64",
      Type.COMPLEX64: "complex<f32>",
      Type.COMPLEX128: "complex<f64>"
  }
  return dtype_to_str[dtype.kind]
142
143
def _nptype_to_taco_type(ty: np.dtype) -> DType:
  """Converts a numpy type to the corresponding TACO data type.

  Args:
    ty: The numpy type to convert. The lookup table is keyed by the numpy
      scalar types, such as np.int8.

  Returns:
    A DType whose kind is the Type enum member matching ty.

  Raises:
    KeyError: If ty is not one of the keys of the lookup table.
  """
  scalar_type_to_kind = {
      np.int8: Type.INT8,
      np.int16: Type.INT16,
      np.int32: Type.INT32,
      np.int64: Type.INT64,
      np.float16: Type.FLOAT16,
      np.float32: Type.FLOAT32,
      np.float64: Type.FLOAT64,
      np.complex64: Type.COMPLEX64,
      np.complex128: Type.COMPLEX128,
  }
  kind = scalar_type_to_kind[ty]
  return DType(kind)
158
159
def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
  """Converts a TACO data type to the corresponding MLIR type.

  Args:
    dtype: The DType whose kind selects the MLIR type.

  Returns:
    A signless integer type, a float type, or a complex type built from a
    float type, depending on the kind of dtype.

  Raises:
    KeyError: If the kind of dtype has no entry in the lookup table.
  """
  signless_int = ir.IntegerType.get_signless
  complex_of = ir.ComplexType.get
  kind_to_irtype = {
      Type.INT8: signless_int(8),
      Type.INT16: signless_int(16),
      Type.INT32: signless_int(32),
      Type.INT64: signless_int(64),
      Type.FLOAT16: ir.F16Type.get(),
      Type.FLOAT32: ir.F32Type.get(),
      Type.FLOAT64: ir.F64Type.get(),
      Type.COMPLEX64: complex_of(ir.F32Type.get()),
      Type.COMPLEX128: complex_of(ir.F64Type.get()),
  }
  return kind_to_irtype[dtype.kind]
174
def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer:
  """Wraps the memref descriptor of a numpy array in a double ctypes pointer.

  Args:
    array: The numpy array to get the ranked memref descriptor for.

  Returns:
    A ctypes pointer to a ctypes pointer to the ranked memref descriptor that
    the MLIR runtime module produces for the array.
  """
  descriptor = runtime.get_ranked_memref_descriptor(array)
  descriptor_ptr = ctypes.pointer(descriptor)
  return ctypes.pointer(descriptor_ptr)
179
180
class ModeFormat(enum.Enum):
  """The tensor dimension storage format class.

  We support the TACO API mode_format class with an alias of this class.

  In TACO, a tensor dimension is called a mode and the storage format for a
  tensor dimension is called a mode format.

  The value of each enum member is the sparse_tensor.DimLevelType used to
  describe the dimension in the MLIR sparse tensor encoding attribute.
  """
  # The dimension uses the dense level type.
  DENSE = sparse_tensor.DimLevelType.dense
  # The dimension uses the compressed level type.
  COMPRESSED = sparse_tensor.DimLevelType.compressed
191
192
def _mode_format_operation(a: ModeFormat, b: ModeFormat,
                           op: _LogicalOp) -> ModeFormat:
  """Combines two ModeFormat values with a logical operator.

  Args:
    a: The first source ModeFormat.
    b: The second source ModeFormat.
    op: The logical operator applied to whether a is COMPRESSED and whether b
      is COMPRESSED.

  Returns:
    ModeFormat.COMPRESSED if the result of op is true, ModeFormat.DENSE
    otherwise.
  """
  is_compressed = op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED)
  if is_compressed:
    return ModeFormat.COMPRESSED
  return ModeFormat.DENSE
199
200
def _mode_format_estimator(op: _BinaryOp) -> _ModeFormatOp:
  """Builds the ModeFormat operator that matches the given binary operator.

  The returned operator is a heuristic for deriving the sparsity of a
  destination dimension from the sparsity of the source dimensions. The binary
  operator is probed with op(0, 1):

  - A zero result, as for the MUL operator, means one zero operand is enough
    to make the result zero. The destination dimension is then estimated to be
    sparse if either source operand is sparse, which is implemented with
    operator.or_.
  - A non-zero result, as for the ADD operator, means both operands need to be
    zero to make the result zero. The destination dimension is then estimated
    to be sparse only if both source operands are sparse, which is implemented
    with operator.and_.

  Args:
    op: A _BinaryOp object representing a supporting operator on tensors.

  Returns:
    A ModeFormatOp for estimating the destination dimension sparsity from
    the source dimension sparsity.
  """
  if op(0, 1) != 0:
    logical_op = operator.and_
  else:
    logical_op = operator.or_
  return functools.partial(_mode_format_operation, op=logical_op)
227
228
def _all_instance_of(collection: Iterable, cls: Any) -> bool:
  """Tells whether every element of the iterable is an instance of cls.

  Args:
    collection: The iterable providing the elements to check.
    cls: A class, or anything else accepted by isinstance as its second
      argument.

  Returns:
    True if all elements are instances of cls, which includes the case of an
    empty iterable. False otherwise.
  """
  for element in collection:
    if not isinstance(element, cls):
      return False
  return True
232
233
def _identity_ordering(rank: int) -> List[int]:
  """Builds the list [0, 1, ..., rank - 1] for a tensor of the given rank."""
  return [*range(rank)]
237
238
@dataclasses.dataclass(frozen=True)
class ModeOrdering:
  """An immutable ordering of the tensor dimensions.

  The TACO API mode_ordering class is supported with an alias of this class.

  Attributes:
    ordering: A list of integers that is a permutation of the dimension
      numbers 0 to rank - 1.
  """
  ordering: List[int]

  def __post_init__(self) -> None:
    """Validates ordering right after the dataclass initialization.

    Raises:
       ValueError: If ordering is not a list of integers, or if it is not a
         permutation of the dimension numbers.
    """
    is_int_list = (isinstance(self.ordering, list) and
                   _all_instance_of(self.ordering, int))
    if not is_int_list:
      raise ValueError("Ordering must be a list of integers: "
                       f"{self.ordering}")
    # A permutation of the dimension numbers equals the identity ordering
    # once it is sorted.
    expected = _identity_ordering(self.rank())
    if sorted(self.ordering) != expected:
      raise ValueError(f"Invalid ordering: {self.ordering} != "
                       f"permutation{_identity_ordering(self.rank())}.")

  def rank(self) -> int:
    """Gives the number of dimensions, which is the length of ordering."""
    return len(self.ordering)
269
270
@dataclasses.dataclass(frozen=True)
class ModeFormatPack:
  """An immutable list of the storage formats of the tensor dimensions.

  The TACO API mode_format_pack class is supported with an alias of this
  class.

  The storage format of a tensor contains one mode_format for each tensor
  dimension.

  Attributes:
    formats: A list of ModeFormat, one for each dimension of the tensor.
  """
  formats: List[ModeFormat]

  def __post_init__(self) -> None:
    """Validates formats right after the dataclass initialization.

    Raises:
       ValueError: If formats is not a list of ModeFormat.
    """
    is_format_list = (isinstance(self.formats, list) and
                      _all_instance_of(self.formats, ModeFormat))
    if not is_format_list:
      raise ValueError("Formats must be a list of ModeFormat: "
                       f"{self.formats}")

  def rank(self) -> int:
    """Gives the number of dimensions, which is the length of formats."""
    return len(self.formats)
300
301
@dataclasses.dataclass
class Format:
  """The tensor storage format class defined by the TACO API.

  Attributes:
    format_pack: A ModeFormatPack holding the storage format of each tensor
      dimension.
    ordering: A ModeOrdering holding the order of the tensor dimensions in the
      storage.
  """
  format_pack: ModeFormatPack
  ordering: Optional[ModeOrdering] = None

  def __post_init__(self) -> None:
    """Normalizes and validates format_pack and ordering.

    The TACO API initializer syntax allows plain lists for both attributes. A
    list of ModeFormat given as format_pack is wrapped in a ModeFormatPack and
    a list of integers given as ordering is wrapped in a ModeOrdering. When
    ordering is not provided, the identity ordering for the rank of
    format_pack is used.

    Raises:
       ValueError: If format_pack cannot be turned into a ModeFormatPack, if
         ordering cannot be turned into a ModeOrdering, or if the two have
         different ranks.
    """
    # Normalize format_pack first because the default ordering needs its rank.
    if isinstance(self.format_pack, list):
      if not _all_instance_of(self.format_pack, ModeFormat):
        raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}")
      self.format_pack = ModeFormatPack(self.format_pack)
    elif not isinstance(self.format_pack, ModeFormatPack):
      raise ValueError(f"Expected ModeFormatpack: {self.format_pack}")

    if self.ordering is None:
      self.ordering = ModeOrdering(_identity_ordering(self.rank()))
    elif isinstance(self.ordering, list):
      if not _all_instance_of(self.ordering, int):
        raise ValueError(f"Expected a list of integer: {self.ordering}")
      self.ordering = ModeOrdering(self.ordering)
    elif not isinstance(self.ordering, ModeOrdering):
      raise ValueError(f"Expected ModeOrdering: {self.ordering}")

    if self.format_pack.rank() != self.ordering.rank():
      raise ValueError("Inconsistent ModeFormatPack and ModeOrdering: "
                       f"len({self.format_pack}) != "
                       f"len({self.ordering})")

  def rank(self) -> int:
    """Gives the number of dimensions, taken from format_pack."""
    return self.format_pack.rank()

  def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]:
    """Builds numpy arrays describing the ordering and the sparsity.

    Returns:
      A tuple (perm, sparse). perm holds the dimension ordering as unsigned
      long long values. sparse holds one uint8 value for each dimension, 0 for
      a DENSE dimension and 1 otherwise.
    """
    sparsity_flags = [
        1 if mode_format != ModeFormat.DENSE else 0
        for mode_format in self.format_pack.formats
    ]
    perm = np.array(self.ordering.ordering, dtype=np.ulonglong)
    sparse = np.array(sparsity_flags, dtype=np.uint8)
    return (perm, sparse)

  def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]:
    """Builds the MLIR sparse tensor encoding attribute for the format.

    Returns:
      The EncodingAttr created from the level type of each dimension, the
      permutation map for the dimension ordering, and the pointer and index
      bitwidths defined by this module.
    """
    if self.ordering is None:
      order = range(self.rank())
    else:
      order = self.ordering.ordering
    level_types = [
        mode_format.value for mode_format in self.format_pack.formats
    ]
    permutation = ir.AffineMap.get_permutation(order)
    return sparse_tensor.EncodingAttr.get(level_types, permutation,
                                          _POINTER_BIT_WIDTH, _INDEX_BIT_WIDTH)
369
370
def _make_format(formats: List[ModeFormat],
                 ordering: Optional[List[int]] = None) -> Format:
  """Builds a Format from a list of ModeFormat and an optional ordering.

  Args:
    formats: A list of ModeFormat, one for each dimension of a tensor.
    ordering: An optional list of integers giving the ordering of the tensor
      dimensions. The identity ordering is used when ordering is None or
      empty.

  Returns:
    A tensor format object.

  Raises:
    ValueError: If formats is not a list of ModeFormat or the length of formats
      is not consistent with the length of ordering.
  """
  if not ordering:
    ordering = _identity_ordering(len(formats))
  format_pack = ModeFormatPack(formats)
  mode_ordering = ModeOrdering(ordering)
  return Format(format_pack, mode_ordering)
389
390
class IndexExpr(abc.ABC):
  """The base class for the nodes of an index notation expression tree.

  The TACO API index_expression class is supported with an alias of this class.
  """

  def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
    """Checks the RHS operand and builds a binary expression with it.

    Args:
      rhs: The RHS of the binary operation. This comes from user inputs and
        can be any Python object.
      op: A _BinaryOp object representing the binary operator.

    Returns:
      A _BinaryExpr built from op, self and rhs.

    Raises:
      ValueError: If rhs is not an IndexExpr.
    """
    if isinstance(rhs, IndexExpr):
      return _BinaryExpr(op, self, rhs)
    raise ValueError(f"Expected IndexExpr: {rhs}")

  def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
    """Builds a unary expression that applies op to self.

    Args:
      op: A _UnaryOp object representing the unary operation.

    Returns:
      A _UnaryExpr built from op and self.
    """
    return _UnaryExpr(op, self)

  def __add__(self, rhs) -> "_BinaryExpr":
    """Implements the operator + for index expressions.

    Args:
      rhs: The value to add. This comes from user inputs and can be any Python
        object.

    Returns:
      A _BinaryExpr object representing the addition.

    Raises:
      ValueError: If rhs is not an IndexExpr.
    """
    return self._verify_operand_and_build_expr(rhs, operator.add)

  def __mul__(self, rhs) -> "_BinaryExpr":
    """Implements the operator * for index expressions.

    Args:
      rhs: The value to multiply with. This comes from user inputs and can be
        any Python object.

    Returns:
      A _BinaryExpr object representing the multiplication.

    Raises:
      ValueError: If rhs is not an IndexExpr.
    """
    return self._verify_operand_and_build_expr(rhs, operator.mul)

  def __abs__(self) -> "_UnaryExpr":
    """Implements the builtin abs for index expressions.

    Returns:
      A _UnaryExpr object representing the absolute value operation.
    """
    return self._build_unary_expr(operator.abs)

  def __neg__(self) -> "_UnaryExpr":
    """Implements the unary operator - for index expressions.

    Returns:
      A _UnaryExpr object representing the negation.
    """
    return self._build_unary_expr(operator.neg)

  def __sub__(self, rhs) -> "_BinaryExpr":
    """Implements the binary operator - for index expressions.

    Args:
      rhs: The value to subtract. This comes from user inputs and can be any
        Python object.

    Returns:
      A _BinaryExpr object representing the subtraction.

    Raises:
      ValueError: If rhs is not an IndexExpr.
    """
    return self._verify_operand_and_build_expr(rhs, operator.sub)

  @abc.abstractmethod
  def _visit(self,
             func: _ExprVisitor,
             args,
             *,
             leaf_checker: _SubtreeLeafChecker = None) -> None:
    """Visits the expression tree in post-order.

    Args:
      func: The callable to apply to each node in the expression tree.
      args: The arguments for the callable, grouped as an iterable that is
        unpacked before the call. Grouping the arguments enables the keyword
        only syntax for the parameter after this one.
      leaf_checker: A callable identifying the nodes to treat as leaf nodes,
        which supports visiting only part of the tree.
    """

  @abc.abstractmethod
  def _emit_expression(
      self,
      expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
      expr_to_info: _ExprInfoDict,
  ) -> lang.ScalarExpression:
    """Emits the linalg scalar expression for the expression tree.

    Args:
      expr_to_opnd: A dictionary for looking up the structured op input
        operands for the input nodes of the structured op.
      expr_to_info: A dictionary for looking up code generation information for
        expressions.

    Returns:
      A linalg dialect ScalarExpression for the expression.
    """

  @abc.abstractmethod
  def dtype(self) -> DType:
    """Gives the data type of the result of the expression."""

  def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
    """Emits a linalg structured op for the expression tree rooted at self.

    A DefinedOpCallable is defined in the domain specific language for the
    linalg dialect and is then executed to generate the structured op. The
    resulting MLIR value is recorded in the expression information for self.

    Args:
      expr_to_info: A dictionary for looking up code generation information for
        expressions.
    """
    structop_info = expr_to_info[self].structop_info
    structop_name = structop_info.dst_name
    structop_def = lang.LinalgOpDef(name=structop_name)
    structop_callable = lang.DefinedOpCallable(structop_name, structop_def)

    # Gather the expression nodes that are the inputs of the structured op.
    input_exprs = []
    self._visit(
        _gather_structured_op_input,
        (self, expr_to_info, input_exprs),
        leaf_checker=_is_structured_op_leaf,
    )

    # Create one structured op operand for each input node and remember which
    # operand belongs to which node.
    input_expr_to_opnd = {}
    for input_expr in input_exprs:
      input_expr_to_opnd[input_expr] = _emit_structured_op_input(
          input_expr, expr_to_info, structop_def)

    # The expression tree produces the value to store to the destination.
    value = self._emit_expression(input_expr_to_opnd, expr_to_info)
    # Represent the destination tensor as the output operand of the op.
    dst_opnd = _emit_operand(structop_def, structop_info.dst_indices,
                             structop_info.dst_name,
                             lang.OperandKind.OUTPUT_TENSOR)
    dst_dims = _mlir_dimensions_from_index_vars(structop_info.dst_indices)
    dst_use = lang.TensorUse(dst_opnd, dst_dims)

    expr_info = expr_to_info[self]
    # A reduction needs to be represented explicitly, by wrapping the value in
    # a ReduceFn over the reduced dimensions. Only add reduction is supported
    # currently.
    if expr_info.reduce_indices:
      reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
      value = lang.ReduceFn.add[reduce_dims](value)

    # The assignment is a comprehension in the linalg dialect.
    structop_def.comprehensions.append(lang.Comprehension((dst_use, value)))

    # The structured op requires an explicitly initialized destination tensor.
    init = structop_info.emit_tensor_init()

    # The MLIR input values are collected in the insertion order of the
    # dictionary, which is the order in which the operands were created.
    input_values = [
        expr_to_info[input_expr].mlir_value for input_expr in input_expr_to_opnd
    ]
    # Executing the DefinedOpCallable emits the MLIR for the structured op.
    expr_info.mlir_value = structop_callable(*input_values, outs=[init])

  def _identify_structured_ops(
      self,
      expr_to_info: _ExprInfoDict,
      dst: "Tensor",
      dst_indices: Tuple["IndexVar", ...],
  ) -> List["IndexExpr"]:
    """Finds the expression nodes that are the roots of structured ops.

    A structured op in the linalg dialect only supports a reduction performed
    on the whole expression. An expression tree with reductions on parts of
    the tree therefore has to be implemented with several structured ops. This
    routine identifies the expression nodes that contain a reduction as the
    roots of the structured ops, and also sets up the _StructOpInfo for self
    from the destination tensor.

    Args:
      expr_to_info: A dictionary for looking up code generation information for
        expressions.
      dst: A destination Tensor that accepts the value of the expression tree.
      dst_indices: The indices used by the destination index expression.

    Returns:
      An ordered list of IndexExpr for the root expressions of the structured
      ops, where child expressions go before parent expressions that use their
      results.
    """
    # The indices used by the source expression but not by the destination
    # are the indices being reduced.
    reduce_indices = tuple(
        set(expr_to_info[self].src_indices) - set(dst_indices))
    for reduce_index in reduce_indices:
      _mark_structured_op_root(self, reduce_index, expr_to_info)

    self._visit(_accumulate_reduce_indices, (expr_to_info,))
    roots = []
    self._visit(_gather_structured_op, (expr_to_info, roots))

    # The top level expression is always a structured op root. Add it when it
    # has not been gathered as the last root, which happens when it is not a
    # reduction.
    if not roots or roots[-1] != self:
      roots.append(self)

    # The _StructOpInfo for the top level expression comes from the user
    # specified information for the destination tensor.
    expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
                                                     tuple(dst.shape),
                                                     dst.dtype, dst.name,
                                                     dst.format)

    return roots

  def _validate_and_collect_expr_info(
      self,
      dst: "Tensor",
      dst_indices: Tuple["IndexVar", ...],
  ) -> _ExprInfoDict:
    """Validates the expression tree and collects expression information.

    The expression information is built by visiting the tree with the module
    level routine of the same name. The dimensions of the destination tensor
    are then checked against the dimension information collected for the
    indices of this expression.

    Args:
      dst: A destination Tensor that accepts the value of the expression tree.
      dst_indices: The indices used by the destination index expression.

    Raises:
      ValueError if there is any inconsistency in indices or dimensional
      values.

    Returns:
      A dictionary of (IndexExpr, _ExprInfo).
    """
    expr_to_info = {}
    # The visitor validates each node and fills in the dictionary.
    self._visit(_validate_and_collect_expr_info, (expr_to_info,))

    # Check the destination dimensions against the source dimensions.
    info = expr_to_info[self]
    index_to_dim_info = dict(zip(info.src_indices, info.dim_infos))
    for i, d in zip(dst_indices, dst.shape):
      if i not in index_to_dim_info:
        raise ValueError("Destination IndexVar not used in the "
                         f"source expression: {i}")
      src_dim = index_to_dim_info[i].dim
      # A source dimension of -1 is not compared with the destination.
      if d != src_dim and src_dim != -1:
        raise ValueError(f"Inconsistent destination dimension for {i}: "
                         f"{d} vs {index_to_dim_info[i].dim}")

    return expr_to_info

  def _emit_assignment(
      self,
      module: ir.Module,
      dst: "Tensor",
      dst_indices: Tuple["IndexVar", ...],
      expr_to_info: _ExprInfoDict,
      input_accesses: List["Access"],
  ) -> None:
    """Emits an MLIR function that assigns the expression to a tensor.

    Args:
      module: The MLIR module to emit the function into.
      dst: The destination tensor of the assignment.
      dst_indices: The indices used to access the destination tensor.
      expr_to_info: A dictionary for looking up code generation information for
        expressions.
      input_accesses: The input accesses of the expression, which become the
        arguments of the emitted function.
    """
    input_types = [
        access.tensor.mlir_tensor_type() for access in input_accesses
    ]

    # The kernel for the operations is built in the body of the module.
    with ir.InsertionPoint(module.body):

      @func.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
      def linalg_funcop(*args):
        # Record the MLIR value of each Access node, which is the function
        # argument at the same position.
        for access, mlir_value in zip(input_accesses, args):
          expr_to_info[access].mlir_value = mlir_value

        # The assignment is implemented with structured ops in the linalg
        # dialect, emitted in the order of the identified roots.
        structop_roots = self._identify_structured_ops(expr_to_info, dst,
                                                       dst_indices)
        for structop_root in structop_roots:
          structop_root._emit_structured_op(expr_to_info)
          dst._record_stats(expr_to_info[structop_root].structop_info)

        # The MLIR value of the root expression is the function result.
        return expr_to_info[self].mlir_value

      linalg_funcop.func_op.attributes[
          "llvm.emit_c_interface"] = ir.UnitAttr.get()

  def get_input_accesses(self) -> List["Access"]:
    """Collects the list of input accesses for the expression."""
    accesses = []
    self._visit(_gather_input_accesses_index_vars, (accesses,))
    return accesses

  def compile(
      self,
      dst: "Tensor",
      dst_indices: Tuple["IndexVar", ...],
  ) -> execution_engine.ExecutionEngine:
    """Compiles the tensor assignment dst[dst_indices] = expression.

    Args:
      dst: The destination tensor.
      dst_indices: The tuple of IndexVar used to access the destination tensor.

    Returns:
      The execution engine for the tensor assignment.

    Raises:
      ValueError: If the expression is not proper or not supported.
    """
    expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
    input_accesses = self.get_input_accesses()

    # The module is built and compiled inside a fresh MLIR context.
    with ir.Context(), ir.Location.unknown():
      module = ir.Module.create()
      self._emit_assignment(module, dst, dst_indices, expr_to_info,
                            input_accesses)
      return utils.compile_and_build_engine(module)
748
749
class _AtomicCounter:
  """An atomic counter that starts from zero."""

  def __init__(self):
    self._counter = 0
    self._counter_lock = threading.Lock()

  def increment(self) -> int:
    """Increments the counter by one and returns the old value.

    Returns:
      The value of the counter before the increment. Each call returns a
      distinct value, also when the calls come from different threads.
    """
    # The old value has to be read under the same lock as the update.
    # Otherwise two threads can read the same value before either of them
    # writes the incremented value back.
    with self._counter_lock:
      old_value = self._counter
      self._counter = old_value + 1
    return old_value
763
764
class IndexVar(IndexExpr):
  """An index variable in tensor index notation.

  The TACO API index_var class is supported with an alias of this class.

  Each IndexVar gets its name from a counter shared by all IndexVar objects.

  Attributes:
    name: A unique string name of the IndexVar.
  """
  _counter = _AtomicCounter()

  def __init__(self):
    unique_id = self._counter.increment()
    self._name = f"{_TACO_INDEX_PREFIX}{unique_id}"

  def __repr__(self) -> str:
    return f"IndexVar(name={self._name!r})"

  @property
  def name(self) -> str:
    """The name assigned to the IndexVar at construction."""
    return self._name

  def _visit(self,
             func: _ExprVisitor,
             args,
             *,
             leaf_checker: _SubtreeLeafChecker = None) -> None:
    """Applies func to this node, which has no child node to visit."""
    if leaf_checker:
      # An IndexVar is expected to be accepted as a leaf by the checker.
      assert leaf_checker(self, *args)
    func(self, *args)

  def _emit_expression(
      self,
      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
      expr_to_info: _ExprInfoDict,
  ) -> lang.ScalarExpression:
    """Emits the index value cast to the data type of the tensor expression.

    The index of the dimension named after the IndexVar is first cast to I64
    and then cast to the type variable T, both with unsigned casts.
    """
    index_value = lang.index(getattr(lang.D, self.name))
    as_i64 = lang.TypeFn.cast_unsigned(lang.TV.I64, index_value)
    return lang.TypeFn.cast_unsigned(lang.T, as_i64)

  def dtype(self) -> DType:
    """Gives the data type for the index value.

    This is not expected to be called for an IndexVar.
    """
    assert False
816
817
def get_index_vars(n: int) -> List[IndexVar]:
  """Returns a list of n IndexVar.

  This routine is defined by the TACO API.

  Args:
    n: An integer representing the number of IndexVar to get.

  Returns:
    A list of n newly created IndexVar.

  Raises:
    ValueError: if n is not a positive integer.
  """
  if not isinstance(n, int) or n <= 0:
    # The message states the actual requirement. Zero and negative integers
    # are rejected as well, not only non-integer values.
    raise ValueError(f"Expected a positive integer: {n}.")
  # If lock contention ever becomes an issue, we could implement a bulk getter
  # that returns a range by only claiming the lock once.
  return [IndexVar() for _ in range(n)]
837
838
def _mlir_symbols_from_index_vars(
    index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.SymbolDef, ...]:
  """Looks up the MLIR symbol named after each of the given index_var.

  Args:
    index_vars: The tuple of IndexVar to get the symbols for.

  Returns:
    A tuple with the attribute of lang.S named after each IndexVar, in the
    order of index_vars.
  """
  symbols = [getattr(lang.S, index_var.name) for index_var in index_vars]
  return tuple(symbols)
843
844
def _mlir_dimensions_from_index_vars(
    index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.DimDef, ...]:
  """Looks up the MLIR dimension named after each of the given index_var.

  Args:
    index_vars: The tuple of IndexVar to get the dimensions for.

  Returns:
    A tuple with the attribute of lang.D named after each IndexVar, in the
    order of index_vars.
  """
  dimensions = [getattr(lang.D, index_var.name) for index_var in index_vars]
  return tuple(dimensions)
849
850
def _mlir_tensor_type(
    dtype: DType, shape: Tuple[int, ...],
    attr: Optional[sparse_tensor.EncodingAttr]) -> ir.RankedTensorType:
  """Builds the MLIR ranked tensor type for a tensor.

  Args:
    dtype: A DType object for the element data type of the tensor.
    shape: A tuple of integers for the shape of the tensor.
    attr: An optional MLIR sparse tensor attribute, only provided if the tensor
      is a sparse tensor.

  Returns:
    An MLIR ranked tensor type with the given shape, the MLIR element type
    corresponding to dtype, and attr as its encoding.
  """
  return ir.RankedTensorType.get(shape, _mlir_type_from_taco_type(dtype), attr)
867
868
@dataclasses.dataclass(frozen=True)
class _StructOpInfo:
  """Information for generating a structured op in the linalg dialect.

  This information is associated with an expression node that serves as the
  root for an expression subtree implemented with a structured op.

  Attributes:
    dst_indices: A tuple of IndexVar, representing the result dimensions of the
      structured op. This is used to construct the temporary variable for the
      tensor to hold the structured op result.
    dst_dims: A tuple of int, representing the result shape of the structured
      op.
    dst_dtype: A DType representing the data type of the structured op result.
    dst_name: A string representing the name of the structured op result.
    dst_format: An optional Format object representing the destination tensor
      format. None represents a true dense tensor.
  """
  dst_indices: Tuple[IndexVar, ...]
  dst_dims: Tuple[int, ...]
  dst_dtype: DType
  dst_name: str
  dst_format: Optional[Format]

  def __post_init__(self) -> None:
    """Verifies that there is exactly one dimension size per result index."""
    assert len(self.dst_indices) == len(self.dst_dims)

  def emit_tensor_init(self) -> ir.RankedTensorType:
    """Returns an initialization for the destination tensor.

    A destination without a format, or with a rank-0 format, is initialized as
    a dense tensor filled with zero. Any other destination is allocated as a
    sparse tensor carrying the encoding of the destination format.
    """
    if self.dst_format is None or self.dst_format.rank() == 0:
      # Initialize the dense tensor.
      ir_type = _mlir_type_from_taco_type(self.dst_dtype)
      tensor = linalg.InitTensorOp(self.dst_dims, ir_type).result
      zero = arith.ConstantOp(ir_type, 0.0)
      return linalg.fill(zero, outs=[tensor])

    # Initialize the sparse tensor.
    mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims,
                                  self.dst_format.mlir_tensor_attr())
    return bufferization.AllocTensorOp(mlir_type, [], None, None)
911
912
class _Stats:
  """Information to describe how a tensor expression is implemented.

  Currently, we only record the temporary tensors introduced for splitting the
  original expression.
  """

  def __init__(self) -> None:
    # One _StructOpInfo per recorded temporary tensor, in insertion order.
    self._temps: List[_StructOpInfo] = []

  def __repr__(self) -> str:
    return f"_Stats({self._temps!r})"

  def add_element(self, structop: _StructOpInfo) -> None:
    """Adds a temporary tensor."""
    self._temps.append(structop)

  def get_total(self) -> int:
    """Gets the total number of temporary tensors."""
    return len(self._temps)

  def _get_element(self, idx: int) -> _StructOpInfo:
    """Gets the ith temporary tensor."""
    assert idx < self.get_total()
    return self._temps[idx]

  def get_dimensions(self, idx: int) -> Tuple[int, ...]:
    """Gets the dimensions for the ith temporary tensor."""
    return self._get_element(idx).dst_dims

  def get_formats(self, idx: int) -> Tuple[ModeFormat, ...]:
    """Gets the ModeFormats for the ith temporary tensor.

    This reads the format of the recorded temporary, so it requires that the
    temporary was recorded with a destination format.
    """
    return tuple(self._get_element(idx).dst_format.format_pack.formats)
946
947
class _SparseValueInfo(enum.Enum):
  """Describes where the value of a sparse tensor is currently stored.

  _UNPACKED: The sparse tensor value is stored as (coordinates, values) in
    Python.
  _PACKED: The sparse tensor value is stored as a C pointer to a packed MLIR
    sparse tensor.
  """
  _UNPACKED = 0
  _PACKED = 1
957
958
@dataclasses.dataclass(frozen=True)
class _Assignment:
  """Records an assignment to a tensor T as T[indices] = expression.

  Attributes:
    indices: A tuple of IndexVar used on the left-hand side of the assignment.
    expression: The IndexExpr on the right-hand side of the assignment.
  """
  indices: Tuple["IndexVar", ...]
  expression: "IndexExpr"
964
965
class Tensor:
  """The tensor class.

  We support the TACO API tensor class with an alias of this class.

  This class is part of the TACO API with the following methods:
    insert: Inserts a value to the given coordinate in the tensor.
    to_array: Returns a numpy ndarray for the tensor.

  TACO API also defines the following attributes for the class:
    dtype: A dtype object representing the data type of the tensor.
    format: A format object representing the storage format of the tensor.
    name: A string object representing the name of the tensor.
    order: An integral rank of the tensor.
    shape: A list of integers representing the shape of the tensor.

  We currently ignore the tensor dimension ordering for dense tensor.
  """
  _counter = _AtomicCounter()

  def _get_unique_name(self) -> str:
    """Returns a unique name for creating a new Tensor."""
    return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}"

  def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat],
                                    Format]) -> None:
    """Processes the fmt argument for the Tensor constructor.

    Args:
      fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
        this argument is a ModeFormat, uses this ModeFormat for all the tensor
        dimensions. If this argument is a list of ModeFormat, the len of the
        list should equal to the rank of the tensor. If this argument is a
        format, uses it for the format of the tensor.

    Raises:
      ValueError: If fmt is not one of the expected types or is inconsistent
        with the rank of the tensor. This is because fmt could be a user
        input.
    """
    if isinstance(fmt, ModeFormat):
      self._format = _make_format([fmt] * self.order)
    elif isinstance(fmt, list):
      # Check every entry rather than only the first one, so that a stray
      # non-ModeFormat element is reported as a ValueError and an empty list
      # does not trigger an IndexError.
      if (len(fmt) == self.order and
          all(isinstance(f, ModeFormat) for f in fmt)):
        self._format = _make_format(fmt)
      else:
        raise ValueError("Inconsistent shape and format: "
                         f"{self._shape}, {fmt}.")
    elif isinstance(fmt, Format):
      if fmt.rank() != self.order:
        raise ValueError("Inconsistent shape and format: "
                         f"{self._shape}, {fmt}.")
      else:
        self._format = fmt
    else:
      raise ValueError(f"Invalid format argument: {fmt}.")

  def __init__(self,
               value_or_shape: Optional[Union[List[int], Tuple[int, ...],
                                              complex, float, int]] = None,
               fmt: Optional[Union[ModeFormat, List[ModeFormat],
                                   Format]] = None,
               dtype: Optional[DType] = None,
               name: Optional[str] = None,
               is_dense: bool = False):
    """The tensor constructor interface defined by TACO API.

    Args:
      value_or_shape: This argument is optional and can be int, float, complex,
        List[int], or Tuple[int, ...]. If this argument is an int, float or
        complex, creates a scalar tensor and initializes it with the value. If
        this argument is a list or tuple of int, uses it as the shape to create
        a tensor. If it is None, creates a scalar tensor without a value.
      fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
        this argument is a ModeFormat, uses this ModeFormat for all the tensor
        dimensions. If this argument is a list of ModeFormat, the len of the
        list should equal to the rank of the tensor. If this argument is a
        format, uses it for the format of the tensor.
      dtype: An object of dtype, representing the data type of the tensor.
      name: A string name of the tensor. If a name is not given, creates a
        unique name for the tensor.
      is_dense: A boolean variable to indicate whether the tensor is a dense
        tensor without any sparsity annotation.

    Raises:
      ValueError: If there is any inconsistency among the input arguments.
    """
    # Take care of the argument default values common to both sparse tensors
    # and dense tensors.
    dtype = dtype or DType(Type.FLOAT32)
    self._name = name or self._get_unique_name()
    self._assignment = None
    self._engine = None
    self._sparse_value_location = _SparseValueInfo._UNPACKED
    self._dense_storage = None
    self._dtype = dtype
    # Every tensor gets a stats record, so that _record_stats also works for
    # dense tensors without sparsity annotation.
    self._stats = _Stats()

    if is_dense:
      assert (fmt is None)
      assert (isinstance(value_or_shape, tuple) or isinstance(
          value_or_shape, list)) and _all_instance_of(value_or_shape, int)
      self._shape = value_or_shape
      self._format = None
      return

    fmt = fmt or ModeFormat.COMPRESSED
    # We currently use _coords and _values to host the sparse tensor value with
    # COO format, and _dense_storage to host the dense tensor value. We don't
    # support the conversion between the two storages.
    self._coords = []
    self._values = []
    if value_or_shape is None or isinstance(value_or_shape, int) or isinstance(
        value_or_shape, float) or isinstance(value_or_shape, complex):
      # Create a scalar tensor and ignore the fmt parameter.
      self._shape = []
      self._format = _make_format([], [])
      if value_or_shape is not None:
        self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value)
    elif (isinstance(value_or_shape, tuple) or isinstance(
        value_or_shape, list)) and _all_instance_of(value_or_shape, int):
      # Create a tensor with the specified shape and format.
      self._shape = list(value_or_shape)
      self._init_format(fmt)
    else:
      raise ValueError("Invalid first argument. "
                       "Must be a tuple or list for a shape or a single value "
                       f"if initializing a scalar tensor: {value_or_shape}.")

  def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None:
    """Records the MLIR sparse tensor pointer and marks the value as packed."""
    self._sparse_value_location = _SparseValueInfo._PACKED
    self._packed_sparse_value = pointer

  def is_unpacked(self) -> bool:
    """Returns true if the tensor value is not packed as MLIR sparse tensor."""
    return (self._sparse_value_location == _SparseValueInfo._UNPACKED)

  def unpack(self) -> None:
    """Unpacks the MLIR sparse tensor representation.

    Does nothing for dense tensors and for tensors that are already unpacked.
    """
    if self.is_dense() or self.is_unpacked():
      return

    # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
    # values and verify the values.
    rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
        self._packed_sparse_value, self._dtype.value)
    assert rank == self.order
    assert np.array_equal(self.shape, shape)
    assert nse == len(values)
    self._coords = indices
    self._values = values
    self._sparse_value_location = _SparseValueInfo._UNPACKED

  def __repr__(self) -> str:
    # Printing a tensor evaluates any pending assignment and unpacks the value.
    self._sync_value()
    self.unpack()
    value_str = (f"{repr(self._dense_storage)})" if self.is_dense() else
                 f"{repr(self._coords)} {repr(self._values)})")
    return (f"Tensor(_name={repr(self._name)} "
            f"_dtype={repr(self._dtype)} : ") + value_str

  def insert(self, coords: List[int], val: Union[complex, float, int]) -> None:
    """Inserts a value to the given coordinate.

    Args:
      coords: A list of integer coordinates. The length of the list must be the
        same as the rank of the tensor.
      val: A value being inserted. It is an int, float or complex value. This
        value will be converted to the data type of the tensor.

    Raises:
      ValueError: When there is any problem in the parameters, when the tensor
        is dense, or when the tensor has a pending assignment or a packed
        value.
    """
    if self.is_dense():
      raise ValueError("Insert method is not supported for dense tensors.")
    if self._assignment is not None or not self.is_unpacked():
      raise ValueError(
          "Can't use Insert method for a tensor constructed from a file.")
    if not isinstance(coords, list):
      raise ValueError(f"Non list coordinate detected: {coords}.")
    if not _all_instance_of(coords, int):
      raise ValueError(f"Non integer coordinate detected: {coords}.")
    if (len(coords) != self.order or
        any(c < 0 or c >= self._shape[i] for i, c in enumerate(coords))):
      raise ValueError("Invalid coordinate for rank: "
                       f"{self.order}, {coords}.")

    if not isinstance(val, int) and not isinstance(
        val, float) and not isinstance(val, complex):
      raise ValueError(f"Value is neither int nor float: {val}.")

    self._coords.append(tuple(coords))
    self._values.append(self._dtype.value(val))

  def is_dense(self) -> bool:
    """Returns true if the tensor doesn't have sparsity annotation.

    Scalar tensors (rank 0) are treated as dense as well.
    """
    return self.order == 0 or self._format is None

  def to_array(self) -> np.ndarray:
    """Returns the numpy array for the Tensor.

    This is currently only implemented for dense Tensor. Any pending assignment
    to the tensor is evaluated first.

    Raises:
      ValueError: If the tensor is not dense.
    """
    if not self.is_dense():
      raise ValueError("Conversion from non-dense Tensor "
                       "to numpy array not supported yet.")

    self._sync_value()

    return self._dense_storage

  @staticmethod
  def from_array(array: np.ndarray) -> "Tensor":
    """Returns a dense tensor with the value copied from the input array.

    We currently only support the conversion of float32 and float64 numpy arrays
    to Tensor.

    Args:
      array: The numpy array that provides the data type, shape and value for
        the tensor.

    Returns:
      A Tensor object.

    Raises:
      ValueError if the data type of the numpy array is not supported.
    """
    if array.dtype != np.float32 and array.dtype != np.float64:
      raise ValueError(f"Expected floating point value type: {array.dtype}.")
    tensor = Tensor(
        array.shape,
        dtype=_nptype_to_taco_type(array.dtype.type),
        is_dense=True)
    tensor._dense_storage = np.copy(array)
    return tensor

  @staticmethod
  def from_coo(
      coordinates: List[Tuple[int, ...]],
      values: List[_AnyRuntimeType],
      fmt: Format,
      dtype: DType,
  ) -> "Tensor":
    """Converts coordinates and values to a sparse tensor representation.

    Args:
      coordinates: A list of coordinates with non-zero values.
      values: The non-zero values.
      fmt: The tensor storage format.
      dtype: The tensor element data type.

    Returns:
      A tensor with the given non-zero values and storage format. The shape of
      the tensor has the minimum size for each dimension to make the given
      coordinates valid.
    """
    assert (isinstance(coordinates, List) and
            _all_instance_of(coordinates, Tuple))
    assert (isinstance(values, List) and _all_instance_of(values, dtype.value))
    assert isinstance(fmt, Format)

    rank = fmt.rank()
    assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates)

    # Find the maximum coordinate value for each dimension.
    max_coordinate = list(map(max, zip(*coordinates)))
    # The size of each dimension is one more than such a maximum coordinate
    # value.
    shape = [c + 1 for c in max_coordinate]
    tensor = Tensor(shape, fmt, dtype=dtype)
    tensor._coords = coordinates
    tensor._values = values

    return tensor

  @staticmethod
  def from_file(
      filename: str,
      fmt: Format,
      dtype: DType,
  ) -> "Tensor":
    """Constructs a sparse tensor using the COO-flavored values from a file.

    Args:
      filename: A string for the name of the file that contains the sparse
        tensor data.
      fmt: The tensor storage format.
      dtype: The tensor element data type.

    Returns:
      A tensor with the given non-zero values and storage format. The tensor
      value is stored as an MLIR sparse tensor.
    """
    sparse_tensor, shape = utils.create_sparse_tensor(filename,
                                                      fmt.format_pack.formats,
                                                      _dtype_to_mlir_str(dtype))
    tensor = Tensor(shape.tolist(), fmt, dtype=dtype)
    tensor._set_packed_sparse_tensor(sparse_tensor)

    return tensor

  def to_file(self, filename: str) -> None:
    """Outputs the tensor value to a file.

    This method evaluates any pending assignment to the tensor and outputs the
    tensor value.

    Args:
      filename: A string file name.

    Raises:
       ValueError: If the tensor is dense, or an unpacked sparse tensor.
    """
    self._sync_value()

    if self.is_dense():
      raise ValueError("Writing dense tensors without sparsity annotation to "
                       "file is not supported.")

    if self.is_unpacked():
      raise ValueError("Writing unpacked sparse tensors to file is not "
                       "supported.")

    utils.output_sparse_tensor(self._packed_sparse_value, filename,
                               self._format.format_pack.formats,
                               _dtype_to_mlir_str(self._dtype))

  @property
  def dtype(self) -> DType:
    """Returns the data type for the Tensor."""
    return self._dtype

  @property
  def format(self) -> Optional[Format]:
    """Returns the storage format for the Tensor.

    This is None for a dense tensor without sparsity annotation.
    """
    return self._format

  @property
  def name(self) -> str:
    """Returns the name for the Tensor."""
    return self._name

  @property
  def order(self) -> int:
    """Returns the rank of the Tensor."""
    return len(self._shape)

  @property
  def shape(self) -> List[int]:
    """Returns the shape of the Tensor."""
    return self._shape

  def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
    """Verifies and normalizes the indices to access the tensor.

    Args:
      indices: The index expression used to access a tensor, which could be any
        Python object from user inputs.

    Returns:
      A tuple of IndexVar.

    Raises:
      ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
        a tuple of IndexVar for other tensors.
    """
    if self.order == 0:
      # A scalar tensor is accessed as T[0] and has no index variables.
      if not isinstance(indices, int) or indices != 0:
        raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
      return ()

    if isinstance(indices, IndexVar):
      return (indices,)
    elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
      return indices

    raise ValueError(f"Expected IndexVars: {indices}")

  def __getitem__(self, key) -> "Access":
    """Verifies and processes a tensor access.

    In the tensor index notation, a tensor access T[i, j] is represented as
    retrieving a value with key (i, j) from the tensor object T in Python. This
    routine verifies the key for the tensor access and returns a tensor access
    object.

    Args:
      key: The key used to access the tensor, which could be any Python object
        from user inputs.

    Returns:
      The corresponding tensor access object.

    Raises:
      ValueError: If key is not an IndexVar or a tuple of IndexVar.
    """
    indices = self._verify_and_normalize_indices(key)
    return Access(self, indices)

  def __setitem__(self, key, value) -> None:
    """Verifies and records a tensor assignment.

    In the tensor index notation, a tensor assignment "T[i, j] = ..." is
    represented as setting a value for a tensor object T via key (i, j) in
    Python. This routine verifies the key and records the assignment on the
    tensor. The assignment is not evaluated here, and any previously compiled
    engine is discarded.

    Args:
      key: The key used to access the tensor, which could be any Python object
        from user inputs.
      value: The value assigned to the tensor, which could be any Python object
        from user inputs.

    Raises:
      ValueError: If the key is not an IndexVar or a tuple of IndexVar, or the
        length of the indices is not the same as the rank of the tensor.
    """
    indices = self._verify_and_normalize_indices(key)
    if len(indices) != self.order:
      raise ValueError("Mismatch between indices and tensor rank: "
                       f"len({indices}) != {self.order}.")

    self._assignment = _Assignment(indices, value)
    self._engine = None

  def compile(self, force_recompile: bool = False) -> None:
    """Compiles the tensor assignment to an execution engine.

    Calling compile the second time does not do anything unless
    force_recompile is True. Nothing happens either when there is no pending
    assignment.

    Args:
      force_recompile: A boolean value to enable recompilation, such as for the
        purpose of timing.

    Raises:
      ValueError: If the assignment is not proper or not supported.
    """
    if self._assignment is None or (self._engine is not None and
                                    not force_recompile):
      return

    self._engine = self._assignment.expression.compile(self,
                                                       self._assignment.indices)

  def compute(self) -> None:
    """Executes the engine for the tensor assignment.

    Does nothing when there is no pending assignment. After a successful run,
    both the pending assignment and the engine are cleared.

    Raises:
      ValueError: If the assignment hasn't been compiled yet.
    """
    if self._assignment is None:
      return

    if self._engine is None:
      raise ValueError("Need to invoke compile() before invoking compute().")

    input_accesses = self._assignment.expression.get_input_accesses()
    # Gather the pointers for the input buffers.
    input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
    if self.is_dense():
      # The pointer to receive dense output is the first argument to the
      # execution engine.
      arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers
    else:
      # The pointer to receive the sparse tensor output is the last argument
      # to the execution engine and is a pointer to pointer of char.
      arg_pointers = input_pointers + [
          ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
      ]

    # Invoke the execution engine to run the module.
    self._engine.invoke(_ENTRY_NAME, *arg_pointers)

    # Retrieve the result.
    if self.is_dense():
      result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
      assert isinstance(result, np.ndarray)
      self._dense_storage = result
    else:
      self._set_packed_sparse_tensor(arg_pointers[-1][0])

    self._assignment = None
    self._engine = None

  def evaluate(self) -> None:
    """Evaluates the tensor assignment by compiling and then computing it."""
    self.compile()
    self.compute()

  def _sync_value(self) -> None:
    """Updates the tensor value by evaluating the pending assignment."""
    if self._assignment is not None:
      self.evaluate()

  def mlir_tensor_type(self) -> ir.RankedTensorType:
    """Returns the MLIR type for the tensor.

    The sparse tensor attribute is only attached when the tensor has a format
    and is not a scalar.
    """
    mlir_attr = (None if (self._format is None or self.order == 0) else
                 self._format.mlir_tensor_attr())
    return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)

  def dense_dst_ctype_pointer(self) -> ctypes.pointer:
    """Returns the ctypes pointer for the pointer to an MemRefDescriptor.

    For a dense tensor output, the MLIR compiler allocates the storage for
    the tensor. This routine returns the pointer to an MLIR MemRefDescriptor for
    receiving the tensor.
    """
    assert self.is_dense()
    mem_ref_desc = runtime.make_nd_memref_descriptor(
        self.order, np.ctypeslib.as_ctypes_type(self.dtype.value))()
    return ctypes.pointer(ctypes.pointer(mem_ref_desc))

  def ctype_pointer(self) -> ctypes.pointer:
    """Returns the ctypes pointer for the pointer to the input tensor.

    A dense tensor without a value is given zero-initialized storage first. An
    unpacked sparse tensor is converted from its COO value to an MLIR sparse
    tensor; a packed sparse tensor reuses the recorded pointer.
    """
    if self.is_dense():
      if self._dense_storage is None:
        self._dense_storage = np.zeros(self._shape, self._dtype.value)
      return _ctype_pointer_from_array(self._dense_storage)

    if self.is_unpacked():
      shape = np.array(self._shape, np.int64)
      indices = np.array(self._coords, np.int64)
      values = np.array(self._values, self._dtype.value)
      perm, sparse = self.format.get_permutation_and_sparsity()
      ptr = utils.coo_tensor_to_sparse_tensor(shape, values, indices, perm,
                                              sparse)
    else:
      ptr = self._packed_sparse_value

    return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))

  def get_scalar_value(self) -> _AnyRuntimeType:
    """Returns the value for the scalar tensor.

    This method also evaluates the assignment to the tensor.

    Raises:
      ValueError: If the tensor is not a scalar.
    """
    if self.order != 0:
      raise ValueError(f"Expected a scalar tensor, got: rank={self.order}")

    self._sync_value()
    return self._dense_storage

  def get_coordinates_and_values(
      self) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]:
    """Returns the coordinates and values for the non-zero elements.

    This method also evaluates the assignment to the tensor and unpacks the
    sparse tensor.
    """
    self._sync_value()

    if not self.is_dense():
      self.unpack()
      return (self._coords, self._values)

    if self.order == 0:
      return ([], self._dense_storage)

    # Coordinates for non-zero elements, grouped by dimensions.
    coords_by_dims = self._dense_storage.nonzero()
    # Coordinates for non-zero elements, grouped by elements.
    coords = np.transpose(coords_by_dims)
    values = self._dense_storage[coords_by_dims]
    return (coords, values)

  def _record_stats(self, structop: "_StructOpInfo") -> None:
    """Collects information for temporary tensors."""
    # Exclude user specified destination tensors.
    if structop.dst_name == self.name:
      return

    self._stats.add_element(structop)
1548
1549
def _emit_operand(op_def: lang.LinalgOpDef, indices: Tuple[IndexVar, ...],
                  name: str, kind: lang.OperandKind) -> lang.OperandDef:
  """Registers an operand for a tensor access on the current linalg operation.

  Args:
    op_def: A LinalgOpDef representing the current linalg dialect operation.
    indices: A tuple of IndexVar used to access the tensor.
    name: A unique string name of the tensor.
    kind: An OperandKind for the operand.

  Returns:
    The OperandDef that was created and added to op_def under the given name.
  """
  # The operand is sized by one MLIR symbol per index variable.
  symbols = _mlir_symbols_from_index_vars(indices)
  operand = lang.OperandDef(kind, lang.T, symbols)
  op_def.add_operand(name, operand)
  return operand
1567
1568
@dataclasses.dataclass(frozen=True)
class _DimInfo:
  """Size and sparsity of a single operand dimension.

  Attributes:
    dim: An integer for the size of the dimension.
    mode_format: A ModeFormat for the dimension sparsity.
  """
  dim: int
  mode_format: ModeFormat
1579
1580
def _get_dummy_dim_info() -> _DimInfo:
  """Constructs the _DimInfo for an index used in tensor expressions.

  The returned placeholder has size -1 and a dense mode format.
  """
  return _DimInfo(dim=-1, mode_format=ModeFormat.DENSE)
1584
1585
@dataclasses.dataclass()
class _ExprInfo:
  """Expression information for validation and code generation.

  Attributes:
    src_indices: A tuple of IndexVar for the indices used by the tensors in the
      expression tree.
    dim_infos: A tuple of _DimInfo, representing the dimension information
      corresponding to the src_indices.
    reduce_indices: A set of IndexVar for the indices reduced by the expression.
    acc_reduce_indices: An accumulated set of IndexVar for the indices reduced
      by the expression and its children.
    structop_info: Information to support the code generation for a structured
      op in the linalg dialect, if the corresponding expression node is the root
      of a subtree for a structured op.
    mlir_value: The MLIR value generated for the structured op.
  """
  src_indices: Tuple[IndexVar, ...]
  dim_infos: Tuple[_DimInfo, ...]
  reduce_indices: Optional[Set[IndexVar]] = None
  acc_reduce_indices: Optional[Set[IndexVar]] = None
  structop_info: Optional[_StructOpInfo] = None
  mlir_value: Optional[ir.Value] = None

  def __post_init__(self) -> None:
    """Verifies and fixes up attribute values.

    Checks that there is one _DimInfo per source index, then replaces a missing
    or empty reduce_indices / acc_reduce_indices with a fresh empty set so that
    callers can omit them in the initializer.
    """
    assert len(self.src_indices) == len(self.dim_infos)
    if not self.reduce_indices:
      self.reduce_indices = set()
    if not self.acc_reduce_indices:
      self.acc_reduce_indices = set()
1619
1620
@dataclasses.dataclass(frozen=True)
class Access(IndexExpr):
  """The tensor access class.

  We support the TACO API access class with an alias of this class.

  Attributes:
    tensor: A Tensor being accessed.
    indices: A tuple of IndexVar, representing the indices used to access the
      Tensor.
  """
  tensor: Tensor
  indices: Tuple[IndexVar, ...]

  def __post_init__(self) -> None:
    """Verifies the tensor and indices for a tensor access.

    Raises:
       ValueError: If indices is not a tuple of IndexVar or the len of indices
       doesn't equal to the rank of the tensor.
    """
    if (not isinstance(self.indices, tuple) or
        not _all_instance_of(self.indices, IndexVar)):
      raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.")
    if self.tensor.order != len(self.indices):
      # Report the rank itself; the message used to carry a stray "str" prefix.
      raise ValueError("Invalid indices for rank: "
                       f"{self.tensor.order} != len({str(self.indices)}).")

  def __repr__(self) -> str:
    # The Tensor __repr__ method evaluates the pending assignment to the tensor.
    # We want to define the __repr__ method here to avoid such evaluation of the
    # tensor assignment.
    indices_str = ", ".join(i.name for i in self.indices)
    return f"Tensor({self.tensor.name}) Indices({indices_str})"

  def _emit_expression(
      self,
      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
      expr_to_info: _ExprInfoDict,
  ) -> lang.ScalarExpression:
    """Emits a linalg dialect TensorUse expression for the tensor access."""
    assert self in expr_to_opnd
    dims = _mlir_dimensions_from_index_vars(self.indices)
    return lang.TensorUse(expr_to_opnd[self], dims)

  def _visit(self,
             func: _ExprVisitor,
             args,
             *,
             leaf_checker: _SubtreeLeafChecker = None) -> None:
    """Applies func to this node, which has no children to visit.

    If a leaf_checker is given, it must accept this node with args.
    """
    if leaf_checker:
      assert leaf_checker(self, *args)
    func(self, *args)

  def dtype(self) -> DType:
    """Returns the data type of the accessed tensor."""
    return self.tensor.dtype
1677
1678
def _gather_input_accesses_index_vars(
    expr: IndexExpr,
    input_accesses: List[Access],
) -> None:
  """Appends expr to input_accesses if it is an Access not yet recorded.

  Args:
    expr: The IndexExpr being visited.
    input_accesses: The list of Access nodes collected so far, updated in
      place without introducing duplicates.
  """
  if not isinstance(expr, Access):
    return
  if expr not in input_accesses:
    input_accesses.append(expr)
1686
1687
def _op_ceil(__a: Any) -> Any:
  """A _UnaryOp placeholder object standing for the operation ceil.

  The function body is intentionally empty and the call returns None; the
  function object itself is what identifies the operation.
  """
1691
1692
def _op_floor(__a: Any) -> Any:
  """A _UnaryOp placeholder object standing for the operation floor.

  The function body is intentionally empty and the call returns None; the
  function object itself is what identifies the operation.
  """
1696
1697
def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType:
  """Maps a unary operation to the matching linalg dialect function object.

  Args:
    op: The _UnaryOp to translate.

  Returns:
    The lang.UnaryFn member registered for op.

  Raises:
    KeyError: If op is not one of the operations in the lookup table.
  """
  linalg_unary_fns = {
      _op_floor: lang.UnaryFn.floor,
      _op_ceil: lang.UnaryFn.ceil,
      operator.neg: lang.UnaryFn.negf,
      operator.abs: lang.UnaryFn.abs,
  }
  return linalg_unary_fns[op]
1707
1708
@dataclasses.dataclass(frozen=True)
class _UnaryExpr(IndexExpr):
  """The representation for a Unary operation.

  Attributes:
    op: A _UnaryOp representing the operation.
    a: An IndexExpr representing the operand for the operation.
  """
  # The operation takes a single operand, so it is a _UnaryOp.
  op: _UnaryOp
  a: IndexExpr

  def __post_init__(self) -> None:
    """Verifies that the operand of the operation is an IndexExpr."""
    assert isinstance(self.a, IndexExpr)

  def _emit_expression(
      self,
      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
      expr_to_info: _ExprInfoDict,
  ) -> lang.ScalarExpression:
    """Emits the expression tree and returns the expression.

    Args:
      expr_to_opnd: The dictionary to look up the OperandDef for an IndexExpr.
        This node is present in it only when it is an input operand of the
        structured op being emitted.
      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.

    Returns:
      The linalg unary function applied to the emitted operand, or a TensorUse
      of the operand registered for this node.
    """
    # The current expression node is an internal node of the structured op.
    if self not in expr_to_opnd:
      a = self.a._emit_expression(expr_to_opnd, expr_to_info)
      return _op_unary_to_callable(self.op)(a)

    # The current expression is a leaf node of the structured op. That is, it is
    # a temporary tensor generated by its child structured op.
    op_info = expr_to_info[self].structop_info
    assert op_info is not None
    dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
    return lang.TensorUse(expr_to_opnd[self], dims)

  def _visit(self,
             func: _ExprVisitor,
             args,
             *,
             leaf_checker: Optional[_SubtreeLeafChecker] = None) -> None:
    """A post-order visitor.

    Args:
      func: The visitor invoked as func(node, *args).
      args: The extra positional arguments forwarded to func and leaf_checker.
      leaf_checker: An optional predicate; when it returns true for this node,
        the operand subtree is not visited.
    """
    if leaf_checker is None or not leaf_checker(self, *args):
      self.a._visit(func, args, leaf_checker=leaf_checker)
    func(self, *args)

  def dtype(self) -> DType:
    """Returns the data type of the operation, taken from its operand."""
    return self.a.dtype()
1755
1756
def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
  """Maps a binary operation to the matching linalg dialect function object.

  Args:
    op: The _BinaryOp to translate.

  Returns:
    The lang.BinaryFn member registered for op.

  Raises:
    KeyError: If op is not one of the operations in the lookup table.
  """
  linalg_binary_fns = {
      operator.mul: lang.BinaryFn.mul,
      operator.sub: lang.BinaryFn.sub,
      operator.add: lang.BinaryFn.add,
  }
  return linalg_binary_fns[op]
1765
@dataclasses.dataclass(frozen=True)
class _BinaryExpr(IndexExpr):
  """An index expression node that applies a binary operation.

  Attributes:
    op: The _BinaryOp applied by this node.
    a: The IndexExpr used as the first operand.
    b: The IndexExpr used as the second operand.
  """
  op: _BinaryOp
  a: IndexExpr
  b: IndexExpr

  def __post_init__(self) -> None:
    """Checks that both operands are IndexExpr objects."""
    assert isinstance(self.a, IndexExpr)
    assert isinstance(self.b, IndexExpr)

  def _emit_expression(
      self,
      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
      expr_to_info: _ExprInfoDict,
  ) -> lang.ScalarExpression:
    """Emits the linalg scalar expression for this node.

    Args:
      expr_to_opnd: The dictionary to look up the OperandDef for an IndexExpr.
      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.

    Returns:
      A TensorUse when this node is an operand of the structured op, otherwise
      the linalg binary function applied to the two emitted operands.
    """
    if self in expr_to_opnd:
      # This node is a leaf of the structured op being emitted: its value is a
      # temporary tensor produced by a child structured op.
      structop_info = expr_to_info[self].structop_info
      assert structop_info is not None
      tensor_dims = _mlir_dimensions_from_index_vars(structop_info.dst_indices)
      return lang.TensorUse(expr_to_opnd[self], tensor_dims)

    # Otherwise this is an internal node: emit both operands, then combine.
    lhs = self.a._emit_expression(expr_to_opnd, expr_to_info)
    rhs = self.b._emit_expression(expr_to_opnd, expr_to_info)
    return _op_to_callable(self.op)(lhs, rhs)

  def _visit(self,
             func: _ExprVisitor,
             args,
             *,
             leaf_checker: _SubtreeLeafChecker = None) -> None:
    """Visits both operands (unless this node is a leaf) and then this node."""
    stop_here = leaf_checker is not None and leaf_checker(self, *args)
    if not stop_here:
      for operand in (self.a, self.b):
        operand._visit(func, args, leaf_checker=leaf_checker)
    func(self, *args)

  def dtype(self) -> DType:
    """Returns the data type of the operation, taken from the first operand."""
    first_operand = self.a
    return first_operand.dtype()
1816
1817
def _validate_and_collect_dim_info(
    index_to_dim_info: Dict[IndexVar, _DimInfo],
    indices: Tuple[IndexVar, ...],
    dim_infos: Tuple[_DimInfo, ...],
    expr: _BinaryExpr,
) -> None:
  """Validates and collects the dimension information for an index notation.

  Validates (indices, dim_infos) against the information collected from other
  source operands and is represented by index_to_dim_info. In particular, we
  ensure that each IndexVar corresponds to only one dimension size. We also
  aggregate the new information represented in (indices, dim_infos) to
  index_to_dim_info.

  Args:
    index_to_dim_info: A dictionary of (IndexVar, _DimInfo) collected from the
      previous operands. Updated in place.
    indices: The IndexVars to be validated.
    dim_infos: The dimension information for the IndexVars to be validated.
    expr: The binary expression where (indices, dim_infos) is used.

  Raises:
    ValueError if there is any problem in the IndexVars or dimensional values.
  """
  assert len(indices) == len(dim_infos)
  for i, d in zip(indices, dim_infos):
    if i not in index_to_dim_info:
      index_to_dim_info[i] = d
    else:
      dim = index_to_dim_info[i].dim
      if dim == -1 or d.dim == -1:
        # A size of -1 is compatible with any size; keep whichever side is
        # not -1, if any.
        dim = dim if dim != -1 else d.dim
      elif dim != d.dim:
        raise ValueError(f"Inconsistent source dimension for {i}: "
                         f"{d.dim} vs {dim}")
      mode_format = _mode_format_estimator(expr.op)(
          index_to_dim_info[i].mode_format, d.mode_format)
      # Record the reconciled size rather than d.dim, so that a size already
      # collected is not overwritten by -1 coming from the current operand.
      index_to_dim_info[i] = _DimInfo(dim, mode_format)
1856
1857
def _validate_and_collect_expr_info(
    expr: IndexExpr,
    expr_to_info: _ExprInfoDict,
) -> None:
  """Builds the _ExprInfo for an expression after checking its dimensions.

  Gathers the IndexVars used by the expression together with their dimension
  information, checking on the way that an IndexVar shared by the operands of
  a binary expression has consistent dimensional values. The result is stored
  in expr_to_info.

  This routine is passed to the post-order visitor as an _ExprVisitor object,
  so the operands of expr are already recorded when expr is visited.

  Args:
    expr: The IndexExpr being validated.
    expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
      expression information.

  Raises:
    ValueError if there is any problem in the IndexVars or dimensional values.
  """
  # An Access object may be shared by several expressions, so an expression
  # that is already recorded is not processed again.
  if expr in expr_to_info:
    return

  if isinstance(expr, IndexVar):
    # Both are tuples with one element.
    src_indices = (expr,)
    dim_infos = (_get_dummy_dim_info(),)
  elif isinstance(expr, Access):
    src_indices = expr.indices
    shape = tuple(expr.tensor.shape)
    if expr.tensor.format is not None:
      mode_formats = tuple(expr.tensor.format.format_pack.formats)
    else:
      # A tensor without a format is dense: every dimension counts as DENSE
      # when estimating the storage format of temporary tensors.
      mode_formats = (ModeFormat.DENSE,) * len(shape)
    assert len(shape) == len(mode_formats)
    dim_infos = tuple(
        _DimInfo(size, fmt) for size, fmt in zip(shape, mode_formats))
  else:
    is_unary = isinstance(expr, _UnaryExpr)
    assert is_unary or isinstance(expr, _BinaryExpr)
    first_info = expr_to_info[expr.a]
    index_to_dim_info = dict(zip(first_info.src_indices, first_info.dim_infos))
    if not is_unary:
      second_info = expr_to_info[expr.b]
      _validate_and_collect_dim_info(index_to_dim_info,
                                     second_info.src_indices,
                                     second_info.dim_infos, expr)
    # Dictionaries preserve insertion order, which keeps the keys and the
    # values aligned with each other.
    src_indices = tuple(index_to_dim_info)
    dim_infos = tuple(index_to_dim_info.values())

  expr_to_info[expr] = _ExprInfo(src_indices, dim_infos)
1923
1924
def _mark_structured_op_root(
    expr: IndexExpr,
    reduce_index: IndexVar,
    expr_to_info: _ExprInfoDict,
) -> None:
  """Finds the expression that roots the structured op reducing an IndexVar.

  A linalg structured op can only reduce over its whole expression, whereas a
  TACO tensor algebra expression reduces an IndexVar at the smallest expression
  containing all the uses of that IndexVar. When that expression is only a part
  of the whole expression, the sub-expression tree has to be split out from its
  parent and implemented as a structured op of its own.

  This routine walks down from expr to the node X on which the reduction of
  reduce_index should be performed, and adds reduce_index to
  expr_to_info[X].reduce_indices.

  Args:
    expr: The root IndexExpr for the tensor algebra expression.
    reduce_index: The IndexVar which we want to find out the proper expression
      to perform a reduction.
    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.

  Raises:
      ValueError: If the expression is not proper or not supported.
  """
  info = expr_to_info[expr]
  if isinstance(expr, Access):
    # A simple reduction expression such as A[i] = B[i, j].
    if reduce_index in info.src_indices:
      info.reduce_indices.add(reduce_index)
    return
  if isinstance(expr, IndexVar):
    # A[i] = B[i] + j is not allowed.
    raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.")

  # NOTE(review): a _UnaryExpr node would fail this assertion; confirm whether
  # a reduction through a unary operation needs to be supported here.
  assert (isinstance(expr, _BinaryExpr))
  used_by_a = reduce_index in expr_to_info[expr.a].src_indices
  used_by_b = reduce_index in expr_to_info[expr.b].src_indices

  if used_by_a and used_by_b:
    # Both operands use the index, so this node is the smallest expression
    # that covers all of its uses.
    info.reduce_indices.add(reduce_index)
  elif used_by_a:
    _mark_structured_op_root(expr.a, reduce_index, expr_to_info)
  elif used_by_b:
    _mark_structured_op_root(expr.b, reduce_index, expr_to_info)
  else:
    assert False, "Unreachable path"
1976
1977
def _accumulate_reduce_indices(
    expr: IndexExpr,
    expr_to_info: _ExprInfoDict,
) -> None:
  """Accumulates the reduction indices of the operands into the expression.

  The acc_reduce_indices of a unary or binary expression becomes the union of
  the acc_reduce_indices of its operands and its own reduce_indices. For an
  Access, acc_reduce_indices is set to the reduce_indices set itself.

  This routine is passed to the post-order visitor as an _ExprVisitor object.

  Args:
    expr: The IndexExpr being visited.
    expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
      expression information.
  """
  assert expr in expr_to_info
  info = expr_to_info[expr]

  if isinstance(expr, _BinaryExpr):
    lhs_info = expr_to_info[expr.a]
    rhs_info = expr_to_info[expr.b]
    info.acc_reduce_indices = lhs_info.acc_reduce_indices.union(
        rhs_info.acc_reduce_indices, info.reduce_indices)
  elif isinstance(expr, _UnaryExpr):
    operand_info = expr_to_info[expr.a]
    info.acc_reduce_indices = operand_info.acc_reduce_indices.union(
        info.reduce_indices)
  elif isinstance(expr, IndexVar):
    # An IndexVar reducing itself would mean that the IndexVar is outside the
    # iteration domain. This usage is not allowed and an error should have been
    # emitted before reaching here.
    assert not info.reduce_indices
  else:
    assert isinstance(expr, Access)
    # A simple reduction expression such as A[i] = B[i, j].
    info.acc_reduce_indices = info.reduce_indices
2013
2014
2015
def _gather_structured_op(
    expr: IndexExpr,
    expr_to_info: _ExprInfoDict,
    structop_roots: List[IndexExpr],
) -> None:
  """Records expr as a structured op root if it reduces some indices.

  For such an expression, a _StructOpInfo describing the reduction result is
  attached to its _ExprInfo and the expression is appended to structop_roots.
  Expressions without reduce_indices are left untouched.

  This routine is passed to the post-order visitor as an _ExprVisitor object.

  Args:
    expr: The IndexExpr being visited.
    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
    structop_roots: The resulting list of IndexExpr that are the roots for
      linalg structured ops.
  """
  info = expr_to_info[expr]
  if not info.reduce_indices:
    return

  # The reduction result keeps the indices that are not reduced by this
  # expression or by any expression below it.
  kept = [(index, dim_info)
          for index, dim_info in zip(info.src_indices, info.dim_infos)
          if index not in info.acc_reduce_indices]

  # Attach the description of the result to the expression. The name of the
  # result is derived from the number of roots recorded so far.
  info.structop_info = _StructOpInfo(
      tuple(index for index, _ in kept),
      tuple(dim_info.dim for _, dim_info in kept),
      expr.dtype(),
      f"temp{len(structop_roots)}",
      _make_format([dim_info.mode_format for _, dim_info in kept]),
  )

  structop_roots.append(expr)
2057
2058
def _is_structured_op_leaf(
    expr: IndexExpr,
    root: IndexExpr,
    expr_to_info: _ExprInfoDict,
    *unused_args,
) -> bool:
  """Tells whether the expression is a leaf node of the current structured op.

  The root of a structured op is a leaf of the parent structured op that
  consumes its result. A node is therefore a leaf of the current structured op
  when it is the root of another structured op, or when it is an Access or an
  IndexVar.

  This routine is passed to the post-order visitor as a _SubtreeLeafChecker
  object. The visitor hands the same parameters to the _SubtreeLeafChecker and
  to the _ExprVisitor, hence the extra parameters that are ignored here.

  Args:
    expr: The IndexExpr being visited.
    root: The root of the current structured op.
    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.

  Returns:
    True if the current IndexExpr is a leaf for the current structured op.
  """
  # The root of a structured op other than the current one.
  if expr != root and expr_to_info[expr].structop_info is not None:
    return True
  return isinstance(expr, (Access, IndexVar))
2088
2089
def _gather_structured_op_input(
    expr: IndexExpr,
    root: IndexExpr,
    expr_to_info: _ExprInfoDict,
    structop_inputs: List[IndexExpr],
) -> None:
  """Appends expr to structop_inputs when it feeds the current structured op.

  An IndexExpr is an input of the current structured op if it is an Access
  node, or if it is the root of a structured op other than the current one.
  Each input is recorded only once.

  This routine is passed to the post-order visitor as an _ExprVisitor object.

  Args:
    expr: The IndexExpr being visited.
    root: The root of the current structured op.
    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
    structop_inputs: The resulting list of IndexExpr that provide input to the
      current structured op.
  """
  is_access = isinstance(expr, Access)

  # The root of the current structured op is not its own input, unless the
  # root is an Access.
  if not (is_access or expr != root):
    return
  if expr in structop_inputs:
    return

  if is_access or (expr in expr_to_info and expr_to_info[expr].structop_info):
    structop_inputs.append(expr)
2116
2117
def _emit_structured_op_input(
    expr: IndexExpr,
    expr_to_info: _ExprInfoDict,
    op_def: lang.LinalgOpDef,
) -> lang.OperandDef:
  """Creates and registers the linalg OperandDef for an input IndexExpr.

  Args:
    expr: The input IndexExpr for the current structured op.
    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
    op_def: The linalg operation for the current structured op. The new operand
      is added to it.

  Returns:
    An OperandDef in the linalg dialect for the input IndexExpr.
  """
  structop_info = expr_to_info[expr].structop_info
  if not structop_info or isinstance(expr, Access):
    # The input is a user provided tensor.
    assert isinstance(expr, Access)
    operand_indices, operand_name = expr.indices, expr.tensor.name
  else:
    # The input is a temporary tensor produced by another structured op.
    operand_indices = structop_info.dst_indices
    operand_name = structop_info.dst_name

  operand = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T,
                            _mlir_symbols_from_index_vars(operand_indices))
  op_def.add_operand(operand_name, operand)
  return operand
2148
2149
def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
  """Builds a unary operation after checking the operand.

  Args:
    a: The operand, which could be any Python object from user inputs.
    op: An _UnaryOp object representing the operation.

  Returns:
    A _UnaryExpr object representing the operation.

  Raises:
    ValueError: If a is not an Access.
  """
  if isinstance(a, Access):
    return a._build_unary_expr(op)
  raise ValueError(f"Expected an Access Operand: {a}")
2166
2167
def ceil(a: Access) -> "_UnaryExpr":
  """Builds the index expression for the operation ceil applied to a.

  Args:
    a: The operand, which could be any Python object from user inputs.

  Returns:
    A _UnaryExpr object representing the operation.

  Raises:
    ValueError: If a is rejected by _check_and_build_unary.
  """
  return _check_and_build_unary(a, op=_op_ceil)
2181
2182
def floor(a: Access) -> "_UnaryExpr":
  """Builds the index expression for the operation floor applied to a.

  Args:
    a: The operand, which could be any Python object from user inputs.

  Returns:
    A _UnaryExpr object representing the operation.

  Raises:
    ValueError: If a is rejected by _check_and_build_unary.
  """
  return _check_and_build_unary(a, op=_op_floor)
2196