1# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2# See https://llvm.org/LICENSE.txt for license information. 3# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 5"""Experimental MLIR-PyTACO with sparse tensor support. 6 7See http://tensor-compiler.org/ for TACO tensor compiler. 8 9This module implements the Python classes for PyTACO index notation. These 10include classes for data types, tensor dimension formats (aka mode formats), 11tensor dimension orderings (aka mode ordering), tensor storage formats, and 12tensors. 13 14The PyTACO API doesn't follow the naming conversion required by the style guide 15for this module. As such, we first implement the supporting classes and routines 16following the style guide, and then define the type aliases and constants to 17support the PyTACO API in the pytaco_api module. 18""" 19 20from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union 21 22import abc 23import ctypes 24import dataclasses 25import enum 26import numpy as np 27import functools 28import operator 29import os 30import threading 31 32# Import MLIR related modules. 33from mlir import execution_engine 34from mlir import ir 35from mlir import runtime 36from mlir.dialects import arith 37from mlir.dialects import bufferization 38from mlir.dialects import builtin 39from mlir.dialects import func 40from mlir.dialects import linalg 41from mlir.dialects import sparse_tensor 42from mlir.dialects.linalg.opdsl import lang 43 44from . import mlir_pytaco_utils as utils 45 46# TACO naming prefixes. 47_TACO_INDEX_PREFIX = "i" 48_TACO_TENSOR_PREFIX = "A" 49 50# Bitwidths for pointers and indices. 51_POINTER_BIT_WIDTH = 0 52_INDEX_BIT_WIDTH = 0 53# The entry point to the JIT compiled program. 54_ENTRY_NAME = "main" 55 56# Type aliases for type annotation. 57_UnaryOp = Callable[[Any], Any] 58_BinaryOp = Callable[[Any, Any], Any] 59_ExprVisitor = Callable[..., None] 60_ExprInfoDict = Dict["IndexExpr", "_ExprInfo"] 61_LogicalOp = Callable[[bool, bool], bool] 62_ModeFormatOp = Callable[["ModeFormat", "ModeFormat"], "ModeFormat"] 63_SubtreeLeafChecker = Optional[Callable[..., bool]] 64 65 66class Type(enum.Enum): 67 """The data types supported by TACO. 68 69 We use numpy data types to implement the enum data types. 70 """ 71 INT8 = np.int8 72 INT16 = np.int16 73 INT32 = np.int32 74 INT64 = np.int64 75 FLOAT16 = np.float16 76 FLOAT32 = np.float32 77 FLOAT64 = np.float64 78 COMPLEX64 = np.complex64 79 COMPLEX128 = np.complex128 80 81 82# All floating point type enums. 83_FLOAT_TYPES = (Type.FLOAT16, Type.FLOAT32, Type.FLOAT64) 84# All integral type enums. 85_INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64) 86# All complex type enums. 87_COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128) 88# Type alias for any numpy type used to implement the runtime support for the 89# enum data types. 90_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float16, 91 np.float32, np.float64, np.complex64, np.complex128] 92 93 94@dataclasses.dataclass(frozen=True) 95class DType: 96 """The data type class. 97 98 We support the TACO API dtype class with an alias of this class. 99 100 The following methods are defined by the TACO API: 101 is_float: Returns whether the data type represents a floating point value. 102 is_int: Returns whether the data type represents an integral value. 103 104 Attributes: 105 kind: A Type enum representing the data type. 106 value: The numpy data type for the TACO data type. 107 """ 108 kind: Type = Type.FLOAT32 109 110 def is_float(self) -> bool: 111 """Returns whether the data type represents a floating point value.""" 112 return self.kind in _FLOAT_TYPES 113 114 def is_int(self) -> bool: 115 """Returns whether the data type represents an integral value.""" 116 return self.kind in _INT_TYPES 117 118 def is_complex(self) -> bool: 119 """Returns whether the data type represents a complex value.""" 120 return self.kind in _COMPLEX_TYPES 121 122 @property 123 def value(self) -> _AnyRuntimeType: 124 """Returns the numpy dtype for the data type.""" 125 return self.kind.value 126 127 128def _dtype_to_mlir_str(dtype: DType) -> str: 129 """Returns the MLIR string for the given dtype.""" 130 dtype_to_str = { 131 Type.INT16: "i8", 132 Type.INT16: "i16", 133 Type.INT32: "i32", 134 Type.INT64: "i64", 135 Type.FLOAT16: "f16", 136 Type.FLOAT32: "f32", 137 Type.FLOAT64: "f64", 138 Type.COMPLEX64: "complex<f32>", 139 Type.COMPLEX128: "complex<f64>" 140 } 141 return dtype_to_str[dtype.kind] 142 143 144def _nptype_to_taco_type(ty: np.dtype) -> DType: 145 """Returns the TACO type for the given numpy type.""" 146 nptype_to_dtype = { 147 np.int8: Type.INT8, 148 np.int16: Type.INT16, 149 np.int32: Type.INT32, 150 np.int64: Type.INT64, 151 np.float16: Type.FLOAT16, 152 np.float32: Type.FLOAT32, 153 np.float64: Type.FLOAT64, 154 np.complex64: Type.COMPLEX64, 155 np.complex128: Type.COMPLEX128 156 } 157 return DType(nptype_to_dtype[ty]) 158 159 160def _mlir_type_from_taco_type(dtype: DType) -> ir.Type: 161 """Returns the MLIR type corresponding to the given TACO type.""" 162 dtype_to_irtype = { 163 Type.INT8: ir.IntegerType.get_signless(8), 164 Type.INT16: ir.IntegerType.get_signless(16), 165 Type.INT32: ir.IntegerType.get_signless(32), 166 Type.INT64: ir.IntegerType.get_signless(64), 167 Type.FLOAT16: ir.F16Type.get(), 168 Type.FLOAT32: ir.F32Type.get(), 169 Type.FLOAT64: ir.F64Type.get(), 170 Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()), 171 Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get()) 172 } 173 return dtype_to_irtype[dtype.kind] 174 175def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer: 176 """Returns the ctype pointer for the given numpy array.""" 177 return ctypes.pointer( 178 ctypes.pointer(runtime.get_ranked_memref_descriptor(array))) 179 180 181class ModeFormat(enum.Enum): 182 """The tensor dimension storage format class. 183 184 We support the TACO API mode_format class with an alias of this class. 185 186 In TACO, a tensor dimension is called a mode and the storage format for a 187 tensor dimension is called a mode format. 188 """ 189 DENSE = sparse_tensor.DimLevelType.dense 190 COMPRESSED = sparse_tensor.DimLevelType.compressed 191 192 193def _mode_format_operation(a: ModeFormat, b: ModeFormat, 194 op: _LogicalOp) -> ModeFormat: 195 """Implements the given operator on ModeFormat.""" 196 return (ModeFormat.COMPRESSED 197 if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED) else 198 ModeFormat.DENSE) 199 200 201def _mode_format_estimator(op: _BinaryOp) -> _ModeFormatOp: 202 """Produces a ModeFormat operator for the given binary operator. 203 204 The ModeFormat operator is used as a heuristic to derive the destination 205 dimension sparsity from the source dimension sparsity. In particular, if the 206 binary operator produces a disjunction of the zero values from its source 207 operands, such as the MUL operator, we return a ModeFormat operator that 208 uses operator.or_. That is, we estimate that a dimension for the MUL 209 operation result to be sparse if either of its source operands is sparse. 210 211 On the other hand, if the binary operator produces a conjunction of the 212 zero values from its source operands, such as the ADD operator, we return 213 a ModeFormat operator that uses operator.and_. In this case, we estimate 214 that a dimension for the ADD operation result to be sparse if both of its 215 source operands are sparse. 216 217 Args: 218 op: A _BinaryOp object representing a supporting operator on tensors. 219 220 Returns: 221 A ModeFormatOp for estimating the destination dimension sparsity from 222 the source dimension sparsity. 223 """ 224 conjunction = functools.partial(_mode_format_operation, op=operator.and_) 225 disjunction = functools.partial(_mode_format_operation, op=operator.or_) 226 return conjunction if op(0, 1) != 0 else disjunction 227 228 229def _all_instance_of(collection: Iterable, cls: Any) -> bool: 230 """Returns true if all elements of the iterable is an instance of cls.""" 231 return all(isinstance(e, cls) for e in collection) 232 233 234def _identity_ordering(rank: int) -> List[int]: 235 """Returns the identity ordering for tensor of given rank.""" 236 return list(range(rank)) 237 238 239@dataclasses.dataclass(frozen=True) 240class ModeOrdering: 241 """The tensor dimension ordering class. 242 243 We support the TACO API mode_ordering class with an alias of this class. 244 245 Attributes: 246 ordering: A list of integers representing the ordering of the tensor 247 dimensions. 248 """ 249 ordering: List[int] 250 251 def __post_init__(self) -> None: 252 """Verifies the value in ordering. 253 254 Raises: 255 ValueError: If ordering is not a list of integers. 256 """ 257 if (not isinstance(self.ordering, list) or 258 not _all_instance_of(self.ordering, int)): 259 raise ValueError("Ordering must be a list of integers: " 260 f"{self.ordering}") 261 # Check that ordering is a permutation of the dimension numbers. 262 if sorted(self.ordering) != _identity_ordering(self.rank()): 263 raise ValueError(f"Invalid ordering: {self.ordering} != " 264 f"permutation{_identity_ordering(self.rank())}.") 265 266 def rank(self) -> int: 267 """Returns the number of dimensions represented by the ordering.""" 268 return len(self.ordering) 269 270 271@dataclasses.dataclass(frozen=True) 272class ModeFormatPack: 273 """The tensor dimension format class. 274 275 We support the TACO API mode_format_pack class with an alias of this class. 276 277 The storage format of a tensor contains one mode_format for each tensor 278 dimension. 279 280 Attributes: 281 formats: A list of ModeFormat representing the storage format for each of 282 the tensor dimension. 283 """ 284 formats: List[ModeFormat] 285 286 def __post_init__(self) -> None: 287 """Verifies the value in formats. 288 289 Raises: 290 ValueError: If formats is not a list of ModeFormats. 291 """ 292 if (not isinstance(self.formats, list) or 293 not _all_instance_of(self.formats, ModeFormat)): 294 raise ValueError("Formats must be a list of ModeFormat: " 295 f"{self.formats}") 296 297 def rank(self) -> int: 298 """Returns the number of dimensions represented by the format pack.""" 299 return len(self.formats) 300 301 302@dataclasses.dataclass 303class Format: 304 """The tensor format class defined by the TACO API. 305 306 Attributes: 307 format_pack: A ModeFormatPack representing the storage format for the tensor 308 dimensions. 309 ordering: A ModeOrdering representing the tensor dimension ordering in the 310 storage. 311 """ 312 format_pack: ModeFormatPack 313 ordering: Optional[ModeOrdering] = None 314 315 def __post_init__(self) -> None: 316 """Verifies and fixes up the values in format_pack and ordering. 317 318 Verifies and fixes up the values in format_pack and ordering to supports the 319 initializer syntax defined by the TACO API. If format_pack is a list of 320 ModeFormat, replaces it with ModeFormatPack constructed from the list. If 321 ordering is not provided, set ordering to the natural ordering for the rank 322 corresponding to format_pack. 323 324 Raises: 325 ValueError: If format_pack is not an instance of ModeFormatPack or if 326 ordering is not an instance of ModeOrdering. 327 """ 328 if isinstance(self.format_pack, list): 329 if not _all_instance_of(self.format_pack, ModeFormat): 330 raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}") 331 self.format_pack = ModeFormatPack(self.format_pack) 332 if not isinstance(self.format_pack, ModeFormatPack): 333 raise ValueError(f"Expected ModeFormatpack: {self.format_pack}") 334 335 if self.ordering is None: 336 self.ordering = ModeOrdering(list(range(self.rank()))) 337 if isinstance(self.ordering, list): 338 if not _all_instance_of(self.ordering, int): 339 raise ValueError(f"Expected a list of integer: {self.ordering}") 340 self.ordering = ModeOrdering(self.ordering) 341 if not isinstance(self.ordering, ModeOrdering): 342 raise ValueError(f"Expected ModeOrdering: {self.ordering}") 343 344 if self.format_pack.rank() != self.ordering.rank(): 345 raise ValueError("Inconsistent ModeFormatPack and ModeOrdering: " 346 f"len({self.format_pack}) != " 347 f"len({self.ordering})") 348 349 def rank(self) -> int: 350 """Returns the number of dimensions represented by the format.""" 351 return self.format_pack.rank() 352 353 def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]: 354 """Constructs the numpy arrays for the permutation and sparsity.""" 355 perm = np.array(self.ordering.ordering, dtype=np.ulonglong) 356 a = [0 if s == ModeFormat.DENSE else 1 for s in self.format_pack.formats] 357 sparse = np.array(a, dtype=np.uint8) 358 return (perm, sparse) 359 360 def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]: 361 """Constructs the MLIR attributes for the tensor format.""" 362 order = ( 363 range(self.rank()) if 364 (self.ordering is None) else self.ordering.ordering) 365 mlir_storage_format = [f.value for f in self.format_pack.formats] 366 return sparse_tensor.EncodingAttr.get(mlir_storage_format, 367 ir.AffineMap.get_permutation(order), 368 _POINTER_BIT_WIDTH, _INDEX_BIT_WIDTH) 369 370 371def _make_format(formats: List[ModeFormat], 372 ordering: Optional[List[int]] = None) -> Format: 373 """Constructs a format from a list of ModeFormat and an optional ordering. 374 375 Args: 376 formats: A list of ModeFormat, one for each dimension of a tensor. 377 ordering: An optional list of integer, for the ordering of the tensor 378 dimensions. When an ordering is not given, the identity ordering is used. 379 380 Returns: 381 A tensor format object. 382 383 Raises: 384 ValueError: If formats is not a list of ModeFormat or the length of formats 385 is not consistent with the len of ordering. 386 """ 387 ordering = ordering or _identity_ordering(len(formats)) 388 return Format(ModeFormatPack(formats), ModeOrdering(ordering)) 389 390 391class IndexExpr(abc.ABC): 392 """The index notation base class. 393 394 We support the TACO API index_expression class with an alias of this class. 395 """ 396 397 def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr": 398 """Verifies the RHS operand and returns a binary expression. 399 400 Args: 401 rhs: The RHS of the binary operation, which could be any Python object 402 from user inputs. 403 op: A _BinaryOp object representing the binary operator. 404 405 Raises: 406 ValueError: If rhs is not an IndexExpr. 407 """ 408 if not isinstance(rhs, IndexExpr): 409 raise ValueError(f"Expected IndexExpr: {rhs}") 410 return _BinaryExpr(op, self, rhs) 411 412 def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr": 413 """Build a unary expression. 414 415 Args: 416 op: A _UnaryOp object representing the unary operation. 417 """ 418 return _UnaryExpr(op, self) 419 420 def __add__(self, rhs) -> "_BinaryExpr": 421 """Defines the operator +. 422 423 Args: 424 rhs: The value being added, which could be any Python object from user 425 inputs. 426 427 Returns: 428 A _BinaryExpr object representing the operation. 429 430 Raises: 431 ValueError: If rhs is not an IndexExpr. 432 """ 433 return self._verify_operand_and_build_expr(rhs, operator.add) 434 435 def __mul__(self, rhs) -> "_BinaryExpr": 436 """Defines the operator *. 437 438 Args: 439 rhs: The value being multiplied, which could be any Python object from 440 user inputs. 441 442 Returns: 443 A _BinaryExpr object representing the operation. 444 445 Raises: 446 ValueError: If rhs is not an IndexExpr. 447 """ 448 return self._verify_operand_and_build_expr(rhs, operator.mul) 449 450 def __abs__(self) -> "_UnaryExpr": 451 """Defines the operator abs. 452 453 Returns: 454 A _UnaryExpr object representing the operation. 455 """ 456 return self._build_unary_expr(operator.abs) 457 458 def __neg__(self) -> "_UnaryExpr": 459 """Defines the operator neg. 460 461 Returns: 462 A _UnaryExpr object representing the operation. 463 """ 464 return self._build_unary_expr(operator.neg) 465 466 def __sub__(self, rhs) -> "_BinaryExpr": 467 """Defines the operator -. 468 469 Args: 470 rhs: The value being subtracted, which could be any Python object from 471 user inputs. 472 473 Returns: 474 A _BinaryExpr object representing the operation. 475 476 Raises: 477 ValueError: If rhs is not an IndexExpr. 478 """ 479 return self._verify_operand_and_build_expr(rhs, operator.sub) 480 481 @abc.abstractmethod 482 def _visit(self, 483 func: _ExprVisitor, 484 args, 485 *, 486 leaf_checker: _SubtreeLeafChecker = None) -> None: 487 """A post-order visitor. 488 489 Args: 490 func: A callable applied to each node in the expression tree. 491 args: The variable-length arguments passed to the callable. These 492 arguments are grouped as an iterable and will be unpacked before passing 493 to the callable. This is to enable the keyword argument only syntax 494 after this argument. 495 leaf_checker: A callable object to identify nodes that should be treated 496 as leaf nodes to support partial tree visiting. 497 """ 498 pass 499 500 @abc.abstractmethod 501 def _emit_expression( 502 self, 503 expr_to_opnd: Dict["IndexExpr", lang.OperandDef], 504 expr_to_info: _ExprInfoDict, 505 ) -> lang.ScalarExpression: 506 """Emits MLIR for the expression tree. 507 508 Args: 509 expr_to_opnd: A dictionary for looking up structured op input operands for 510 the input nodes of the structured op. 511 expr_to_info: A dictionary for looking up code generation information for 512 expressions. 513 514 Returns: 515 A linalg dialect ScalarExpression for the expression. 516 """ 517 pass 518 519 @abc.abstractmethod 520 def dtype(self) -> DType: 521 """Returns the data type for the result of the expression.""" 522 pass 523 524 def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None: 525 """Emits a structured op in the linalg dialect for the expression tree. 526 527 We define a DefineOpcallable in the domain specific language for the linalg 528 dialect and execute the callable to generate the structured op. Self is the 529 root of the expression tree for the structured op. 530 531 Args: 532 expr_to_info: A dictionary for looking up code generation information for 533 expressions. 534 """ 535 op_info = expr_to_info[self].structop_info 536 op_name = op_info.dst_name 537 op_def = lang.LinalgOpDef(name=op_name) 538 op_callable = lang.DefinedOpCallable(op_name, op_def) 539 540 # Collect the input expression nodes for the structured op. 541 expr_inputs = [] 542 self._visit( 543 _gather_structured_op_input, 544 (self, expr_to_info, expr_inputs), 545 leaf_checker=_is_structured_op_leaf, 546 ) 547 548 # Create a linalg structured op operand for each input expression node and 549 # build a dictionary for looking up the information. 550 expr_to_input_opnd = { 551 e: _emit_structured_op_input(e, expr_to_info, op_def) 552 for e in expr_inputs 553 } 554 555 # Emit the expression tree, which produces the value assigned to the 556 # destination tensor. 557 value = self._emit_expression(expr_to_input_opnd, expr_to_info) 558 # Emit the structured op representation for the destination tensor. 559 dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name, 560 lang.OperandKind.OUTPUT_TENSOR) 561 dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices) 562 dst_use = lang.TensorUse(dst_opnd, dst_dim_syms) 563 564 expr_info = expr_to_info[self] 565 # If the structured op reduces some indices, explicitly represent the 566 # reduction. This is done by generating a ReduceFn for the dimensions being 567 # reduced in the linalg dialect and calling the function with the value 568 # being reduced. We only support add reduction currently. 569 if expr_info.reduce_indices: 570 reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices) 571 value = lang.ReduceFn.add[reduce_dims](value) 572 573 # Emit the assignment as a comprehension in the linalg dialect. 574 comp = lang.Comprehension((dst_use, value)) 575 op_def.comprehensions.append(comp) 576 577 # The structured op in the linalg dialect requires an explicit 578 # initialization for the destination tensor. Emit MLIR to initialize the 579 # destination tensor. 580 init = op_info.emit_tensor_init() 581 582 # Collect MLIR values for the linalg input operands, with the assumption 583 # that dictionary preserves the insertion order. 584 args = [ 585 expr_to_info[expr].mlir_value 586 for expr, opnd in expr_to_input_opnd.items() 587 ] 588 # Execute the DefineOpcallable object for the linalg dialect operation to 589 # emit MLIR for the linalg structured op. 590 expr_info.mlir_value = op_callable(*args, outs=[init]) 591 592 def _identify_structured_ops( 593 self, 594 expr_to_info: _ExprInfoDict, 595 dst: "Tensor", 596 dst_indices: Tuple["IndexVar", ...], 597 ) -> List["IndexExpr"]: 598 """Returns expression nodes for the roots of the identified structured ops. 599 600 A structured op in the linalg dialect only supports reduction performed on 601 the whole expression. If the expression tree contains reduction that are 602 performed on part of the expression tree, the expression tree needs to be 603 implemented with multiple structured ops. This routine identifies all the 604 expression nodes that contain reduction as the root of structured ops in the 605 linalg dialect. 606 607 Args: 608 expr_to_info: A dictionary for looking up code generation information for 609 expressions. 610 dst: A destination Tensor that accepts the value of the expression tree. 611 dst_indices: The indices used by the destination index expression. 612 613 Returns: 614 An ordered list of IndexExpr for the root expressions of the structured 615 ops, where child expressions go before parent expressions that use their 616 results. 617 """ 618 reduce_indices = tuple( 619 set(expr_to_info[self].src_indices) - set(dst_indices)) 620 for reduce_index in reduce_indices: 621 _mark_structured_op_root(self, reduce_index, expr_to_info) 622 623 self._visit(_accumulate_reduce_indices, (expr_to_info,)) 624 structop_roots = [] 625 self._visit(_gather_structured_op, (expr_to_info, structop_roots)) 626 627 # Handle the root of the top level expression. 628 if not structop_roots or structop_roots[-1] != self: 629 # The top level expression is not a reduction. Add the top level 630 # expression as a structured op root. 631 structop_roots.append(self) 632 633 # Use user specified information for the destination tensor to build an 634 # _StructOpInfo for the top level expression. 635 expr_to_info[self].structop_info = _StructOpInfo(dst_indices, 636 tuple(dst.shape), 637 dst.dtype, dst.name, 638 dst.format) 639 640 return structop_roots 641 642 def _validate_and_collect_expr_info( 643 self, 644 dst: "Tensor", 645 dst_indices: Tuple["IndexVar", ...], 646 ) -> _ExprInfoDict: 647 """Propagates expression information for validation. 648 649 Propagates the indices used by child expression nodes to parent expression 650 nodes. Also collects and validates the sizes for the dimensions 651 corresponding to the indices. 652 653 Args: 654 dst: A destination Tensor that accepts the value of the expression tree. 655 dst_indices: The indices used by the destination index expression. 656 657 Raises: 658 ValueError if there is any inconsistency in indices or dimensional 659 values. 660 661 Returns: 662 A dictionary of (IndexExpr, _ExprInfo). 663 """ 664 expr_to_info = {} 665 # Validate the expression tree and construct expression information. 666 self._visit(_validate_and_collect_expr_info, (expr_to_info,)) 667 668 # Validate the destination dimension information. 669 info = expr_to_info[self] 670 index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)} 671 for i, d, in zip(dst_indices, dst.shape): 672 if i not in index_to_dim_info: 673 raise ValueError("Destination IndexVar not used in the " 674 f"source expression: {i}") 675 else: 676 if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1: 677 raise ValueError(f"Inconsistent destination dimension for {i}: " 678 f"{d} vs {index_to_dim_info[i].dim}") 679 680 return expr_to_info 681 682 def _emit_assignment( 683 self, 684 module: ir.Module, 685 dst: "Tensor", 686 dst_indices: Tuple["IndexVar", ...], 687 expr_to_info: _ExprInfoDict, 688 input_accesses: List["Access"], 689 ) -> None: 690 """Emits an MLIR function for assigning the expression to a tensor.""" 691 input_types = [a.tensor.mlir_tensor_type() for a in input_accesses] 692 693 # Build the kernel for the operations. 694 with ir.InsertionPoint(module.body): 695 696 @func.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME) 697 def linalg_funcop(*args): 698 # Set up the mapping from the Access nodes to their MLIR values. 699 for e, mlir in zip(input_accesses, args): 700 expr_to_info[e].mlir_value = mlir 701 702 # Emit structured ops in the linalg dialect to implement the assignment. 703 for structop_root in self._identify_structured_ops( 704 expr_to_info, dst, dst_indices): 705 structop_root._emit_structured_op(expr_to_info) 706 dst._record_stats(expr_to_info[structop_root].structop_info) 707 708 # The function returns the MLIR value of the root expression. 709 return expr_to_info[self].mlir_value 710 711 linalg_funcop.func_op.attributes[ 712 "llvm.emit_c_interface"] = ir.UnitAttr.get() 713 714 def get_input_accesses(self) -> List["Access"]: 715 """Compute the list of input accesses for the expression.""" 716 input_accesses = [] 717 self._visit(_gather_input_accesses_index_vars, (input_accesses,)) 718 return input_accesses 719 720 def compile( 721 self, 722 dst: "Tensor", 723 dst_indices: Tuple["IndexVar", ...], 724 ) -> execution_engine.ExecutionEngine: 725 """Compiles the tensor assignment dst[dst_indices] = expression. 726 727 Args: 728 dst: The destination tensor. 729 dst_indices: The tuple of IndexVar used to access the destination tensor. 730 731 Returns: 732 The execution engine for the tensor assignment. 733 734 Raises: 735 ValueError: If the expression is not proper or not supported. 736 """ 737 expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices) 738 input_accesses = self.get_input_accesses() 739 740 # Build and compile the module to produce the execution engine. 741 with ir.Context(), ir.Location.unknown(): 742 module = ir.Module.create() 743 self._emit_assignment(module, dst, dst_indices, expr_to_info, 744 input_accesses) 745 engine = utils.compile_and_build_engine(module) 746 747 return engine 748 749 750class _AtomicCounter: 751 """An atomic counter.""" 752 753 def __init__(self): 754 self._counter = 0 755 self._counter_lock = threading.Lock() 756 757 def increment(self) -> int: 758 """Increments the counter by one and returns the old value.""" 759 old_value = self._counter 760 with self._counter_lock: 761 self._counter = self._counter + 1 762 return old_value 763 764 765class IndexVar(IndexExpr): 766 """The tensor index class. 767 768 We support the TACO API index_var class with an alias of this class. 769 770 An IndexVar object represents an index variable in tensor index notation. 771 772 Attributes: 773 name: A unique string name of the IndexVar. 774 """ 775 _counter = _AtomicCounter() 776 777 def __init__(self): 778 id = self._counter.increment() 779 self._name = f"{_TACO_INDEX_PREFIX}{id}" 780 781 def __repr__(self) -> str: 782 return f"IndexVar(name={repr(self._name)})" 783 784 @property 785 def name(self) -> str: 786 """Returns the name of the IndexVar.""" 787 return self._name 788 789 def _visit(self, 790 func: _ExprVisitor, 791 args, 792 *, 793 leaf_checker: _SubtreeLeafChecker = None) -> None: 794 """A post-order visitor.""" 795 if leaf_checker: 796 assert leaf_checker(self, *args) 797 func(self, *args) 798 799 def _emit_expression( 800 self, 801 expr_to_opnd: Dict[IndexExpr, lang.OperandDef], 802 expr_to_info: _ExprInfoDict, 803 ) -> lang.ScalarExpression: 804 """Emits a index value casted to the data type of the tensor expression.""" 805 dim = getattr(lang.D, self.name) 806 index = lang.index(dim) 807 int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index) 808 return lang.TypeFn.cast_unsigned(lang.T, int_value) 809 810 def dtype(self) -> DType: 811 """Returns the data type for the index value. 812 813 This is unreachable for IndexVar. 814 """ 815 assert 0 816 817 818def get_index_vars(n: int) -> List[IndexVar]: 819 """Returns a list of n IndexVar. 820 821 This routine is defined by the TACO API. 822 823 Args: 824 n: An integer representing the number of IndexVar to get. 825 826 Returns: 827 A list of IndexVar. 828 829 Raises: 830 ValueError: if n is not a positive integer. 831 """ 832 if not isinstance(n, int) or n <= 0: 833 raise ValueError(f"Expected an integer: {n}.") 834 # If lock contention ever becomes an issue, we could implement a bulk getter 835 # that returns a range by only claiming the lock once. 836 return [IndexVar() for i in range(n)] 837 838 839def _mlir_symbols_from_index_vars( 840 index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.SymbolDef, ...]: 841 """Returns a tuple of MLIR symbols for the given tuple of index_var.""" 842 return tuple(getattr(lang.S, i.name) for i in index_vars) 843 844 845def _mlir_dimensions_from_index_vars( 846 index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.DimDef, ...]: 847 """Returns a tuple of MLIR dimensions for the given tuple of index_var.""" 848 return tuple(getattr(lang.D, i.name) for i in index_vars) 849 850 851def _mlir_tensor_type( 852 dtype: DType, shape: Tuple[int, ...], 853 attr: Optional[sparse_tensor.EncodingAttr]) -> ir.RankedTensorType: 854 """Returns an MLIR tensor type. 855 856 Args: 857 dtype: An DType object for the element data type of the tensor. 858 shape: A tuple of integer for the shape of the tensor. 859 attr: An optional MLIR sparse tensor attribute, only provided if the tensor 860 is a sparse tensor. 861 862 Returns: 863 An MLIR ranked tensor type. 864 """ 865 ir_type = _mlir_type_from_taco_type(dtype) 866 return ir.RankedTensorType.get(shape, ir_type, attr) 867 868 869@dataclasses.dataclass(frozen=True) 870class _StructOpInfo: 871 """Information for generating a structured op in the linalg dialect. 872 873 This information is associated with an expression node that serves as the 874 root for an expression subtree implemented with a structured op. 875 876 Attributes: 877 dst_indices: A tuple of IndexVar, representing the result dimensions of the 878 structured op. This is used to construct the temporary variable for the 879 tensor to hold the structured op result. 880 dst_dims: A tuple of int, representing the result shape of the structured 881 op. 882 dst_dtype: A DType representing the data type of the structured op result. 883 dst_name: A string representing the name of the structured op result. 884 dst_format: An optional Format object representing the destination tensor 885 format. None represents a true dense tensor. 886 """ 887 dst_indices: Tuple[IndexVar, ...] 888 dst_dims: Tuple[int, ...] 889 dst_dtype: DType 890 dst_name: str 891 dst_format: Optional[Format] 892 893 def __post_init__(self) -> None: 894 """Verifies the integrity of the attribute values.""" 895 assert len(self.dst_indices) == len(self.dst_dims) 896 897 def emit_tensor_init(self) -> ir.RankedTensorType: 898 """Returns an initialization for the destination tensor.""" 899 if self.dst_format is None or self.dst_format.rank() == 0: 900 # Initialize the dense tensor. 901 ir_type = _mlir_type_from_taco_type(self.dst_dtype) 902 tensor = linalg.InitTensorOp(self.dst_dims, ir_type).result 903 zero = arith.ConstantOp(ir_type, 0.0) 904 return linalg.fill(zero, outs=[tensor]) 905 906 # Initialize the sparse tensor. 907 mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims, 908 self.dst_format.mlir_tensor_attr()) 909 index_type = ir.IndexType.get() 910 return bufferization.AllocTensorOp(mlir_type, [], None, None) 911 912 913class _Stats: 914 """Information to describe how a tensor expression is implemented. 915 916 Currently, we only record the temporary tensors introduced for splitting the 917 original expression. 918 """ 919 920 def __init__(self): 921 self._temps = [] 922 923 def __repr__(self) -> str: 924 return f"_Stats({repr(self._temps)})" 925 926 def add_element(self, structop: _StructOpInfo): 927 """Adds a temporary tensor.""" 928 self._temps.append(structop) 929 930 def get_total(self) -> int: 931 """Gets the total number of temporary tensors.""" 932 return len(self._temps) 933 934 def _get_element(self, idx: int) -> _StructOpInfo: 935 """Gets the ith temporary tensor.""" 936 assert idx < self.get_total() 937 return self._temps[idx] 938 939 def get_dimensions(self, idx: int) -> Tuple[int]: 940 """Gets the dimensions for the ith temporary tensor.""" 941 return self._get_element(idx).dst_dims 942 943 def get_formats(self, idx: int) -> Tuple[ModeFormat]: 944 """Gets the ModeFormats for the ith temporary tensor.""" 945 return tuple(self._get_element(idx).dst_format.format_pack.formats) 946 947 948class _SparseValueInfo(enum.Enum): 949 """Describes how a sparse tensor value is stored. 950 _UNPACKED: The sparse tensor value is stored as (coordnates, values) in 951 Python. 952 _PACKED: The sparse tensor value is stored as a C pointer to a packed MLIR 953 sparse tensor. 954 """ 955 _UNPACKED = 0 956 _PACKED = 1 957 958 959@dataclasses.dataclass(frozen=True) 960class _Assignment: 961 """Records an assignment to a tensor T as T[indices] = expression.""" 962 indices: Tuple["IndexVar", ...] 963 expression: "IndexExpr" 964 965 966class Tensor: 967 """The tensor class. 968 969 We support the TACO API tensor class with an alias of this class. 970 971 This class is part of the TACO API with the following methods: 972 insert: Inserts a value to the given coordinate in the tensor. 973 to_array: Returns a numpy ndarray for the tensor. 974 975 TACO API also defines the following arrtibutes for the class: 976 dtype: A dtype object representing the data type of the tensor. 977 format: A format object representing the storage format of the tensor. 978 name: A string object representing the name of the tensor. 979 order: An integral rank of the tensor. 980 shape: A list of integers representing the shape of the tensor. 981 982 We currently ignore the tensor dimension ordering for dense tensor. 983 """ 984 _counter = _AtomicCounter() 985 986 def _get_unique_name(self) -> str: 987 """Returns a unique name for creating a new Tensor.""" 988 return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}" 989 990 def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat], 991 Format]) -> None: 992 """Process the fmt argument for the Tensor constructor. 993 994 Args: 995 fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If 996 this argument is a ModeFormat, uses this ModeFormat for all the tensor 997 dimensions. If this argument is a list of ModeFormat, the len of the 998 list should equal to the rank of the tensor. If this argument is a 999 format, uses it for the format of the tensor. 1000 1001 Raises: 1002 ValueError: If fmt is not one of the expected type or is inconsistent 1003 with the rank of the tensor. This is because fmt could be an users 1004 input. 1005 """ 1006 if isinstance(fmt, ModeFormat): 1007 self._format = _make_format([fmt] * self.order) 1008 elif isinstance(fmt, list): 1009 if len(fmt) == self.order and isinstance(fmt[0], ModeFormat): 1010 self._format = _make_format(fmt) 1011 else: 1012 raise ValueError("Inconsistent shape and format: " 1013 f"{self._shape}, {fmt}.") 1014 elif isinstance(fmt, Format): 1015 if fmt.rank() != self.order: 1016 raise ValueError("Inconsistent shape and format: " 1017 f"{self._shape}, {fmt}.") 1018 else: 1019 self._format = fmt 1020 else: 1021 raise ValueError(f"Invalid format argument: {fmt}.") 1022 1023 def __init__(self, 1024 value_or_shape: Optional[Union[List[int], Tuple[int, ...], 1025 complex, float, int]] = None, 1026 fmt: Optional[Union[ModeFormat, List[ModeFormat], 1027 Format]] = None, 1028 dtype: Optional[DType] = None, 1029 name: Optional[str] = None, 1030 is_dense: bool = False): 1031 """The tensor constructor interface defined by TACO API. 1032 1033 Args: 1034 value_or_shape: This argument is optional and can be int, float, 1035 List[int], or Tuple[int, ...]. If this argument is an int or float, 1036 creates a scalar tensor and initializes it with the value. If this 1037 argument is a list or tuple of int, uses it as the shape to create a 1038 tensor. 1039 fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If 1040 this argument is a ModeFormat, uses this ModeFormat for all the tensor 1041 dimensions. If this argument is a list of ModeFormat, the len of the 1042 list should equal to the rank of the tensor. If this argument is a 1043 format, uses it for the format of the tensor. 1044 dtype: An object of dtype, representing the data type of the tensor. 1045 name: A string name of the tensor. If a name is not given, creates a 1046 unique name for the tensor. 1047 is_dense: A boolean variable to indicate whether the tensor is a dense 1048 tensor without any sparsity annotation. 1049 1050 Raises: 1051 ValueError: If there is any inconsistency among the input arguments. 1052 """ 1053 # Take care of the argument default values common to both sparse tensors 1054 # and dense tensors. 1055 dtype = dtype or DType(Type.FLOAT32) 1056 self._name = name or self._get_unique_name() 1057 self._assignment = None 1058 self._engine = None 1059 self._sparse_value_location = _SparseValueInfo._UNPACKED 1060 self._dense_storage = None 1061 self._dtype = dtype 1062 1063 if is_dense: 1064 assert (fmt is None) 1065 assert (isinstance(value_or_shape, tuple) or isinstance( 1066 value_or_shape, list)) and _all_instance_of(value_or_shape, int) 1067 self._shape = value_or_shape 1068 self._format = None 1069 return 1070 1071 fmt = fmt or ModeFormat.COMPRESSED 1072 # We currently use _coords and _values to host the sparse tensor value with 1073 # COO format, and _dense_storage to host the dense tensor value. We don't 1074 # support the conversion between the two storages. 1075 self._coords = [] 1076 self._values = [] 1077 self._stats = _Stats() 1078 if value_or_shape is None or isinstance(value_or_shape, int) or isinstance( 1079 value_or_shape, float) or isinstance(value_or_shape, complex): 1080 # Create a scalar tensor and ignore the fmt parameter. 1081 self._shape = [] 1082 self._format = _make_format([], []) 1083 if value_or_shape is not None: 1084 self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value) 1085 elif (isinstance(value_or_shape, tuple) or isinstance( 1086 value_or_shape, list)) and _all_instance_of(value_or_shape, int): 1087 # Create a tensor with the specified shape and format. 1088 self._shape = list(value_or_shape) 1089 self._init_format(fmt) 1090 else: 1091 raise ValueError("Invalid first argument. " 1092 "Must be a tuple or list for a shape or a single value" 1093 f"if initializing a scalar tensor: {value_or_shape}.") 1094 1095 def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None: 1096 """Records the MLIR sparse tensor pointer.""" 1097 self._sparse_value_location = _SparseValueInfo._PACKED 1098 self._packed_sparse_value = pointer 1099 1100 def is_unpacked(self) -> bool: 1101 """Returns true if the tensor value is not packed as MLIR sparse tensor.""" 1102 return (self._sparse_value_location == _SparseValueInfo._UNPACKED) 1103 1104 def unpack(self) -> None: 1105 """Unpacks the MLIR sparse tensor representation.""" 1106 if self.is_dense() or self.is_unpacked(): 1107 return 1108 1109 # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored 1110 # values and verify the values. 1111 rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor( 1112 self._packed_sparse_value, self._dtype.value) 1113 assert rank == self.order 1114 assert np.array_equal(self.shape, shape) 1115 assert nse == len(values) 1116 self._coords = indices 1117 self._values = values 1118 self._sparse_value_location = _SparseValueInfo._UNPACKED 1119 1120 def __repr__(self) -> str: 1121 self._sync_value() 1122 self.unpack() 1123 value_str = (f"{repr(self._dense_storage)})" if self.is_dense() else 1124 f"{repr(self._coords)} {repr(self._values)})") 1125 return (f"Tensor(_name={repr(self._name)} " 1126 f"_dtype={repr(self._dtype)} : ") + value_str 1127 1128 def insert(self, coords: List[int], val: Union[complex, float, int]) -> None: 1129 """Inserts a value to the given coordinate. 1130 1131 Args: 1132 coords: A list of integer coordinates. The length of the list must be the 1133 same as the rank of the tensor. 1134 val: A value being inserted. It is either an integral or a floating point 1135 value. This value will be converted to the data type of the tensor. 1136 1137 Raises: 1138 ValueError: When there is any problem in the parameters. 1139 """ 1140 if self.is_dense(): 1141 raise ValueError("Insert method is not supported for dense tensors.") 1142 if self._assignment != None or not self.is_unpacked(): 1143 raise ValueError( 1144 "Can't use Insert method for a tensor constructed from a file.") 1145 if not isinstance(coords, list): 1146 raise ValueError(f"Non list coordinate detected: {coords}.") 1147 if not _all_instance_of(coords, int): 1148 raise ValueError(f"Non integer coordinate detected: {coords}.") 1149 if (len(coords) != self.order or 1150 any([c < 0 or c >= self._shape[i] for i, c in enumerate(coords)])): 1151 raise ValueError("Invalid coordinate for rank: " 1152 f"{self.order}, {coords}.") 1153 1154 if not isinstance(val, int) and not isinstance( 1155 val, float) and not isinstance(val, complex): 1156 raise ValueError(f"Value is neither int nor float: {val}.") 1157 1158 self._coords.append(tuple(coords)) 1159 self._values.append(self._dtype.value(val)) 1160 1161 def is_dense(self) -> bool: 1162 """Returns true if the tensor doesn't have sparsity annotation.""" 1163 return self.order == 0 or self._format is None 1164 1165 def to_array(self) -> np.ndarray: 1166 """Returns the numpy array for the Tensor. 1167 1168 This is currenly only implemented for dense Tensor. 1169 """ 1170 if not self.is_dense(): 1171 raise ValueError("Conversion from non-dense Tensor " 1172 "to numpy array not supported yet.") 1173 1174 self._sync_value() 1175 1176 return self._dense_storage 1177 1178 @staticmethod 1179 def from_array(array: np.ndarray) -> "Tensor": 1180 """Returns a dense tensor with the value copied from the input array. 1181 1182 We currently only support the conversion of float32 and float64 numpy arrays 1183 to Tensor. 1184 1185 Args: 1186 array: The numpy array that provides the data type, shape and value for 1187 the tensor. 1188 1189 Returns: 1190 A Tensor object. 1191 1192 Raises: 1193 ValueError if the data type of the numpy array is not supported. 1194 """ 1195 if array.dtype != np.float32 and array.dtype != np.float64: 1196 raise ValueError(f"Expected floating point value type: {array.dtype}.") 1197 tensor = Tensor( 1198 array.shape, 1199 dtype=_nptype_to_taco_type(array.dtype.type), 1200 is_dense=True) 1201 tensor._dense_storage = np.copy(array) 1202 return tensor 1203 1204 @staticmethod 1205 def from_coo( 1206 coordinates: List[Tuple[int, ...]], 1207 values: List[_AnyRuntimeType], 1208 fmt: Format, 1209 dtype: DType, 1210 ) -> "Tensor": 1211 """Converts coordinates and values to a sparse tensor representation. 1212 1213 Args: 1214 coordinates: A list of coordinates with non-zero values. 1215 values: The non-zero values. 1216 fmt: The tensor storage format. 1217 dtype: The tensor element data type. 1218 1219 Returns: 1220 A tensor with the given non-zero values and storage format. The shape of 1221 the tensor has the minimum size for each dimension to make the given 1222 coordinates valid. 1223 """ 1224 assert (isinstance(coordinates, List) and 1225 _all_instance_of(coordinates, Tuple)) 1226 assert (isinstance(values, List) and _all_instance_of(values, dtype.value)) 1227 assert isinstance(fmt, Format) 1228 1229 rank = fmt.rank() 1230 assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates) 1231 1232 # Find the maximum coordinate value for each dimension. 1233 max_coordinate = list(map(max, zip(*coordinates))) 1234 # The size of each dimension is one more that such a maximum coordinate 1235 # value. 1236 shape = [c + 1 for c in max_coordinate] 1237 tensor = Tensor(shape, fmt, dtype=dtype) 1238 tensor._coords = coordinates 1239 tensor._values = values 1240 1241 return tensor 1242 1243 @staticmethod 1244 def from_file( 1245 filename: str, 1246 fmt: Format, 1247 dtype: DType, 1248 ) -> "Tensor": 1249 """Constructs a sparse tensor using the COO-flavored values from a file. 1250 1251 Args: 1252 filename: A string for the name of the file that contains the sparse 1253 tensor data. 1254 fmt: The tensor storage format. 1255 dtype: The tensor element data type. 1256 1257 Returns: 1258 A tensor with the given non-zero values and storage format. The tensor 1259 value is stored as an MLIR sparse tensor. 1260 """ 1261 sparse_tensor, shape = utils.create_sparse_tensor(filename, 1262 fmt.format_pack.formats, 1263 _dtype_to_mlir_str(dtype)) 1264 tensor = Tensor(shape.tolist(), fmt, dtype=dtype) 1265 tensor._set_packed_sparse_tensor(sparse_tensor) 1266 1267 return tensor 1268 1269 def to_file(self, filename: str) -> None: 1270 """Output the tensor value to a file. 1271 1272 This method evaluates any pending assignment to the tensor and outputs the 1273 tensor value. 1274 1275 Args: 1276 filename: A string file name. 1277 1278 Raises: 1279 ValueError: If the tensor is dense, or an unpacked sparse tensor. 1280 """ 1281 self._sync_value() 1282 1283 if self.is_dense(): 1284 raise ValueError("Writing dense tensors without sparsity annotation to " 1285 "file is not supported.") 1286 1287 if self.is_unpacked(): 1288 raise ValueError("Writing unpacked sparse tensors to file is not " 1289 "supported.") 1290 1291 utils.output_sparse_tensor(self._packed_sparse_value, filename, 1292 self._format.format_pack.formats, 1293 _dtype_to_mlir_str(self._dtype)) 1294 1295 @property 1296 def dtype(self) -> DType: 1297 """Returns the data type for the Tensor.""" 1298 return self._dtype 1299 1300 @property 1301 def format(self) -> Format: 1302 """Returns the storage format for the Tensor.""" 1303 return self._format 1304 1305 @property 1306 def name(self) -> str: 1307 """Returns the name for the Tensor.""" 1308 return self._name 1309 1310 @property 1311 def order(self) -> int: 1312 """Returns the rank of the Tensor.""" 1313 return len(self._shape) 1314 1315 @property 1316 def shape(self) -> List[int]: 1317 """Returns the shape of the Tensor.""" 1318 return self._shape 1319 1320 def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]: 1321 """Verifies and normalizes the indices to access the tensor. 1322 1323 Args: 1324 indices: The index expression used to access a tensor, which could be any 1325 Python object from user inputs. 1326 1327 Returns: 1328 A tuple of IndexVar. 1329 1330 Raises: 1331 ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or 1332 a tuple of IndexVar for other tensors. 1333 """ 1334 if self.order == 0: 1335 if not isinstance(indices, int) or indices != 0: 1336 raise ValueError(f"Expected 0 to index scalar tensors: {indices}") 1337 return () 1338 1339 if isinstance(indices, IndexVar): 1340 return (indices,) 1341 elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar): 1342 return indices 1343 1344 raise ValueError(f"Expected IndexVars: {indices}") 1345 1346 def __getitem__(self, key) -> "Access": 1347 """Verifies and processes a tensor access. 1348 1349 In the tensor index notation, a tensor access T[i, j] is represented as 1350 retrieving a value with key (i, j) from the tensor object T in Python. This 1351 routine verifies the key for the tensor access and returns a tensor access 1352 object. 1353 1354 Args: 1355 key: The key used to access the tensor, which could be any Python object 1356 from user inputs. 1357 1358 Returns: 1359 The corresponding tensor access object. 1360 1361 Raises: 1362 ValueError: If key is not an IndexVar or a tuple of IndexVar. 1363 """ 1364 indices = self._verify_and_normalize_indices(key) 1365 return Access(self, indices) 1366 1367 def __setitem__(self, key, value) -> None: 1368 """Verifies and processes a tensor assignment. 1369 1370 In the tensor index notation, a tensor assignment "T[i, j] = ..." is 1371 represented as setting a value for a tensor object T via key (i, j) in 1372 Python. This routine verifies the key, evaluates the value, and assigns the 1373 value to the tensor. 1374 1375 We only support assignment of dense tensor currently. 1376 1377 Args: 1378 key: The key used to access the tensor, which could be any Python object 1379 from user inputs. 1380 value: The value assigned to the tensor, which could be any Python object 1381 from user inputs. 1382 1383 Raises: 1384 ValueError: If tensor is not a dense tensor, or the key is not an IndexVar 1385 or a tuple of IndexVar, or the length of the indices is not the same as 1386 the rank of the tensor. 1387 """ 1388 indices = self._verify_and_normalize_indices(key) 1389 if len(indices) != self.order: 1390 raise ValueError("Mismatch between indices and tensor rank: " 1391 f"len({indices}) != {self.order}.") 1392 1393 self._assignment = _Assignment(indices, value) 1394 self._engine = None 1395 1396 def compile(self, force_recompile: bool = False) -> None: 1397 """Compiles the tensor assignment to an execution engine. 1398 1399 Calling compile the second time does not do anything unless 1400 force_recompile is True. 1401 1402 Args: 1403 force_recompile: A boolean value to enable recompilation, such as for the 1404 purpose of timing. 1405 1406 Raises: 1407 ValueError: If the assignment is not proper or not supported. 1408 """ 1409 if self._assignment is None or (self._engine is not None and 1410 not force_recompile): 1411 return 1412 1413 self._engine = self._assignment.expression.compile(self, 1414 self._assignment.indices) 1415 1416 def compute(self) -> None: 1417 """Executes the engine for the tensor assignment. 1418 1419 Raises: 1420 ValueError: If the assignment hasn't been compiled yet. 1421 """ 1422 if self._assignment is None: 1423 return 1424 1425 if self._engine is None: 1426 raise ValueError("Need to invoke compile() before invoking compute().") 1427 1428 input_accesses = self._assignment.expression.get_input_accesses() 1429 # Gather the pointers for the input buffers. 1430 input_pointers = [a.tensor.ctype_pointer() for a in input_accesses] 1431 if self.is_dense(): 1432 # The pointer to receive dense output is the first argument to the 1433 # execution engine. 1434 arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers 1435 else: 1436 # The pointer to receive the sparse tensor output is the last argument 1437 # to the execution engine and is a pointer to pointer of char. 1438 arg_pointers = input_pointers + [ 1439 ctypes.pointer(ctypes.pointer(ctypes.c_char(0))) 1440 ] 1441 1442 # Invoke the execution engine to run the module. 1443 self._engine.invoke(_ENTRY_NAME, *arg_pointers) 1444 1445 # Retrieve the result. 1446 if self.is_dense(): 1447 result = runtime.ranked_memref_to_numpy(arg_pointers[0][0]) 1448 assert isinstance(result, np.ndarray) 1449 self._dense_storage = result 1450 else: 1451 self._set_packed_sparse_tensor(arg_pointers[-1][0]) 1452 1453 self._assignment = None 1454 self._engine = None 1455 1456 def evaluate(self) -> None: 1457 """Evaluates the tensor assignment.""" 1458 self.compile() 1459 self.compute() 1460 1461 def _sync_value(self) -> None: 1462 """Updates the tensor value by evaluating the pending assignment.""" 1463 if self._assignment is not None: 1464 self.evaluate() 1465 1466 def mlir_tensor_type(self) -> ir.RankedTensorType: 1467 """Returns the MLIR type for the tensor.""" 1468 mlir_attr = (None if (self._format is None or self.order == 0) else 1469 self._format.mlir_tensor_attr()) 1470 return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr) 1471 1472 def dense_dst_ctype_pointer(self) -> ctypes.pointer: 1473 """Returns the ctypes pointer for the pointer to an MemRefDescriptor. 1474 1475 For a dense tensor output, the MLIR compiler allocates the storage for 1476 the tensor. This routine returns the pointer to an MLIR MemRefDescriptor for 1477 receiving the tensor. 1478 """ 1479 assert self.is_dense() 1480 mem_ref_desc = runtime.make_nd_memref_descriptor( 1481 self.order, np.ctypeslib.as_ctypes_type(self.dtype.value))() 1482 return ctypes.pointer(ctypes.pointer(mem_ref_desc)) 1483 1484 def ctype_pointer(self) -> ctypes.pointer: 1485 """Returns the ctypes pointer for the pointer to the input tensor.""" 1486 if self.is_dense(): 1487 if self._dense_storage is None: 1488 self._dense_storage = np.zeros(self._shape, self._dtype.value) 1489 return _ctype_pointer_from_array(self._dense_storage) 1490 1491 if self.is_unpacked(): 1492 shape = np.array(self._shape, np.int64) 1493 indices = np.array(self._coords, np.int64) 1494 values = np.array(self._values, self._dtype.value) 1495 perm, sparse = self.format.get_permutation_and_sparsity() 1496 ptr = utils.coo_tensor_to_sparse_tensor(shape, values, indices, perm, 1497 sparse) 1498 else: 1499 ptr = self._packed_sparse_value 1500 1501 return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p)) 1502 1503 def get_scalar_value(self) -> _AnyRuntimeType: 1504 """Returns the value for the scalar tensor. 1505 1506 This method also evaluates the assignment to the tensor. 1507 1508 Raises: 1509 ValueError: If the tensor is not a scalar. 1510 """ 1511 if self.order != 0: 1512 raise ValueError(f"Expected a scalar tensor, got: rank={self.order}") 1513 1514 self._sync_value() 1515 return self._dense_storage 1516 1517 1518 def get_coordinates_and_values( 1519 self) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]: 1520 """Returns the coordinates and values for the non-zero elements. 1521 1522 This method also evaluates the assignment to the tensor and unpack the 1523 sparse tensor. 1524 """ 1525 self._sync_value() 1526 1527 if not self.is_dense(): 1528 self.unpack() 1529 return (self._coords, self._values) 1530 1531 if self.order == 0: 1532 return ([], self._dense_storage) 1533 1534 # Coordinates for non-zero elements, grouped by dimensions. 1535 coords_by_dims = self._dense_storage.nonzero() 1536 # Coordinates for non-zero elements, grouped by elements. 1537 coords = np.transpose(coords_by_dims) 1538 values = self._dense_storage[coords_by_dims] 1539 return (coords, values) 1540 1541 def _record_stats(self, structop: "_StructOpInfo"): 1542 """Collects information for temporary tensors.""" 1543 # Exclude user specified destination tensors. 1544 if structop.dst_name == self.name: 1545 return 1546 1547 self._stats.add_element(structop) 1548 1549 1550def _emit_operand(op_def: lang.LinalgOpDef, indices: Tuple[IndexVar, ...], 1551 name: str, kind: lang.OperandKind) -> lang.OperandDef: 1552 """Emits an operand for a tensor access in the current linalg operation. 1553 1554 Args: 1555 op_def: A LinalgOpDef representing the current linalg dialect operation. 1556 indices: A tuple of IndexVar used to access the tensor. 1557 name: A unique string name of the tensor. 1558 kind: An OperandKind for the operand. 1559 1560 Returns: 1561 An OperandDef representing the operand. 1562 """ 1563 dim_sym = _mlir_symbols_from_index_vars(indices) 1564 opnd = lang.OperandDef(kind, lang.T, dim_sym) 1565 op_def.add_operand(name, opnd) 1566 return opnd 1567 1568 1569@dataclasses.dataclass(frozen=True) 1570class _DimInfo: 1571 """Information for an operand dimension. 1572 1573 Attributes: 1574 dim: An integer for the size of the dimension. 1575 mode_format: A ModeFormat for the dimension sparsity. 1576 """ 1577 dim: int 1578 mode_format: ModeFormat 1579 1580 1581def _get_dummy_dim_info() -> _DimInfo: 1582 """Constructs the _DimInfo for an index used in tensor expressions.""" 1583 return _DimInfo(-1, ModeFormat.DENSE) 1584 1585 1586@dataclasses.dataclass() 1587class _ExprInfo: 1588 """Expression information for validation and code generation. 1589 1590 Attributes: 1591 src_indices: A tuple of IndexVar for the indices used by the tensors in the 1592 expression tree. 1593 dim_infos: A tuple of _DimInfo, representing the dimension information 1594 corresponding to the src_indices. 1595 reduce_indices: A set of IndexVar for the indices reduced by the expression. 1596 acc_reduce_indices: An accumulated set of IndexVar for the indices reduced 1597 by the expression and its children. 1598 structop_info: Information to support the code generation for a structured 1599 op in the linalg dialect, if the corresponding expression node is the root 1600 of a subtree for a structured op. 1601 mlir_value: The MLIR value generated for the structured op. 1602 """ 1603 src_indices: Tuple[IndexVar, ...] 1604 dim_infos: Tuple[_DimInfo, ...] 1605 reduce_indices: Optional[Set[IndexVar]] = None 1606 acc_reduce_indices: Optional[Set[IndexVar]] = None 1607 structop_info: Optional[_StructOpInfo] = None 1608 mlir_value: Optional[ir.Value] = None 1609 1610 def __post_init__(self) -> None: 1611 """Verifies and fix up attribute values. 1612 1613 Verifies the consistency of the attributes and modifies the default values 1614 to support convenient initializer syntax. 1615 """ 1616 assert len(self.src_indices) == len(self.dim_infos) 1617 self.reduce_indices = self.reduce_indices or set() 1618 self.acc_reduce_indices = self.acc_reduce_indices or set() 1619 1620 1621@dataclasses.dataclass(frozen=True) 1622class Access(IndexExpr): 1623 """The tensor access class. 1624 1625 We support the TACO API access class with an alias of this class. 1626 1627 Attributes: 1628 tensor: A Tensor being accessed. 1629 indices: A tuple of IndexVar, representing the indices used to access the 1630 Tensor. 1631 """ 1632 tensor: Tensor 1633 indices: Tuple[IndexVar, ...] 1634 1635 def __post_init__(self) -> None: 1636 """Verifies the tensor and indices for a tensor access. 1637 1638 Raises: 1639 ValueError: If indices is not a list of IndexVar or the len of indices 1640 doesn't equal to the rank of the tensor. 1641 """ 1642 if (not isinstance(self.indices, tuple) or 1643 not _all_instance_of(self.indices, IndexVar)): 1644 raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.") 1645 if self.tensor.order != len(self.indices): 1646 raise ValueError("Invalid indices for rank: " 1647 f"str{self.tensor.order} != len({str(self.indices)}).") 1648 1649 def __repr__(self) -> str: 1650 # The Tensor __repr__ method evaluates the pending assignment to the tensor. 1651 # We want to define the __repr__ method here to avoid such evaluation of the 1652 # tensor assignment. 1653 indices_str = ", ".join(map(lambda i: i.name, self.indices)) 1654 return (f"Tensor({self.tensor.name}) " f"Indices({indices_str})") 1655 1656 def _emit_expression( 1657 self, 1658 expr_to_opnd: Dict[IndexExpr, lang.OperandDef], 1659 expr_to_info: _ExprInfoDict, 1660 ) -> lang.ScalarExpression: 1661 """Emits a linalg dialect TensorUse expression for the tensor access.""" 1662 assert self in expr_to_opnd 1663 dims = _mlir_dimensions_from_index_vars(self.indices) 1664 return lang.TensorUse(expr_to_opnd[self], dims) 1665 1666 def _visit(self, 1667 func: _ExprVisitor, 1668 args, 1669 *, 1670 leaf_checker: _SubtreeLeafChecker = None) -> None: 1671 if leaf_checker: 1672 assert leaf_checker(self, *args) 1673 func(self, *args) 1674 1675 def dtype(self) -> DType: 1676 return self.tensor.dtype 1677 1678 1679def _gather_input_accesses_index_vars( 1680 expr: IndexExpr, 1681 input_accesses: List[Access], 1682) -> None: 1683 """Collects Access nodes.""" 1684 if isinstance(expr, Access) and expr not in input_accesses: 1685 input_accesses.append(expr) 1686 1687 1688def _op_ceil(__a: Any) -> Any: 1689 """A _UnaryOp object for operation ceil.""" 1690 pass 1691 1692 1693def _op_floor(__a: Any) -> Any: 1694 """A _UnaryOp object for operation floor.""" 1695 pass 1696 1697 1698def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType: 1699 """Returns the linalg dialect function object for the given operation.""" 1700 op_to_callable = { 1701 operator.abs: lang.UnaryFn.abs, 1702 operator.neg: lang.UnaryFn.negf, 1703 _op_ceil: lang.UnaryFn.ceil, 1704 _op_floor: lang.UnaryFn.floor, 1705 } 1706 return op_to_callable[op] 1707 1708 1709@dataclasses.dataclass(frozen=True) 1710class _UnaryExpr(IndexExpr): 1711 """The representation for a Unary operation. 1712 1713 Attributes: 1714 op: A _UnaryOp representing the operation. 1715 a: An IndexExpr representing the operand for the operation. 1716 """ 1717 op: _BinaryOp 1718 a: IndexExpr 1719 1720 def __post_init__(self) -> None: 1721 """Verifies that the operand being added is an IndexExpr.""" 1722 assert isinstance(self.a, IndexExpr) 1723 1724 def _emit_expression( 1725 self, 1726 expr_to_opnd: Dict[IndexExpr, lang.OperandDef], 1727 expr_to_info: _ExprInfoDict, 1728 ) -> lang.ScalarExpression: 1729 """Emits the expression tree and returns the expression.""" 1730 # The current expression node is an internal node of the structured op. 1731 if self not in expr_to_opnd: 1732 a = self.a._emit_expression(expr_to_opnd, expr_to_info) 1733 return _op_unary_to_callable(self.op)(a) 1734 1735 # The current expression is a leaf node of the structured op. That is, it is 1736 # a temporary tensor generated by its child structured op. 1737 op_info = expr_to_info[self].structop_info 1738 assert op_info is not None 1739 dims = _mlir_dimensions_from_index_vars(op_info.dst_indices) 1740 return lang.TensorUse(expr_to_opnd[self], dims) 1741 1742 def _visit(self, 1743 func: _ExprVisitor, 1744 args, 1745 *, 1746 leaf_checker: _SubtreeLeafChecker = None) -> None: 1747 """A post-order visitor.""" 1748 if leaf_checker is None or not leaf_checker(self, *args): 1749 self.a._visit(func, args, leaf_checker=leaf_checker) 1750 func(self, *args) 1751 1752 def dtype(self) -> DType: 1753 """Returns the data type of the operation.""" 1754 return self.a.dtype() 1755 1756 1757def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType: 1758 """Returns the linalg dialect function object for the given operation.""" 1759 op_to_callable = { 1760 operator.add: lang.BinaryFn.add, 1761 operator.sub: lang.BinaryFn.sub, 1762 operator.mul: lang.BinaryFn.mul, 1763 } 1764 return op_to_callable[op] 1765 1766@dataclasses.dataclass(frozen=True) 1767class _BinaryExpr(IndexExpr): 1768 """The representation for a binary operation. 1769 1770 Attributes: 1771 op: A _BinaryOp representing the binary operation. 1772 a: An IndexExpr representing the first operand of the operation. 1773 b: An IndexExpr representing the second operand of the operation. 1774 """ 1775 op: _BinaryOp 1776 a: IndexExpr 1777 b: IndexExpr 1778 1779 def __post_init__(self) -> None: 1780 """Verifies that the operands being added are IndexExpr.""" 1781 assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr) 1782 1783 def _emit_expression( 1784 self, 1785 expr_to_opnd: Dict[IndexExpr, lang.OperandDef], 1786 expr_to_info: _ExprInfoDict, 1787 ) -> lang.ScalarExpression: 1788 """Emits the expression tree and returns the expression.""" 1789 # The current expression node is an internal node of the structured op. 1790 if self not in expr_to_opnd: 1791 a = self.a._emit_expression(expr_to_opnd, expr_to_info) 1792 b = self.b._emit_expression(expr_to_opnd, expr_to_info) 1793 return _op_to_callable(self.op)(a, b) 1794 1795 # The current expression is a leaf node of the structured op. That is, it is 1796 # a temporary tensor generated by its child structured op. 1797 op_info = expr_to_info[self].structop_info 1798 assert op_info is not None 1799 dims = _mlir_dimensions_from_index_vars(op_info.dst_indices) 1800 return lang.TensorUse(expr_to_opnd[self], dims) 1801 1802 def _visit(self, 1803 func: _ExprVisitor, 1804 args, 1805 *, 1806 leaf_checker: _SubtreeLeafChecker = None) -> None: 1807 """A post-order visitor.""" 1808 if leaf_checker is None or not leaf_checker(self, *args): 1809 self.a._visit(func, args, leaf_checker=leaf_checker) 1810 self.b._visit(func, args, leaf_checker=leaf_checker) 1811 func(self, *args) 1812 1813 def dtype(self) -> DType: 1814 """Returns the data type of the binary operation.""" 1815 return self.a.dtype() 1816 1817 1818def _validate_and_collect_dim_info( 1819 index_to_dim_info: Dict[IndexVar, _DimInfo], 1820 indices: Tuple[IndexVar, ...], 1821 dim_infos: Tuple[_DimInfo, ...], 1822 expr: _BinaryExpr, 1823) -> None: 1824 """Validates and collects the dimension information for an index notation. 1825 1826 Validates (indices, dim_infos) against the information collected from other 1827 source operands and is represented by index_to_dim_info. In particular, we 1828 ensure that each IndexVar corresponds to only one dimension size. We also 1829 aggregate the new information represented in (indices, dim_infos) to 1830 index_to_dim_info. 1831 1832 Args: 1833 index_to_dim: A dictionary of (IndexVar, _DimInfo) collected from the 1834 previous operands. 1835 indices: The IndexVars to be validated. 1836 dim_infos: The dimension information for the IndexVars to be validated. 1837 expr: The binary expression where (indices, dim_infos) is used. 1838 1839 Raises: 1840 ValueError if there is any problem in the IndexVars or dimensional values. 1841 """ 1842 assert len(indices) == len(dim_infos) 1843 for i, d in zip(indices, dim_infos): 1844 if i not in index_to_dim_info: 1845 index_to_dim_info[i] = d 1846 else: 1847 dim = index_to_dim_info[i].dim 1848 if dim == -1 or d.dim == -1: 1849 dim = dim if dim != -1 else d.dim 1850 elif dim != d.dim: 1851 raise ValueError(f"Inconsistent source dimension for {i}: " 1852 f"{d.dim} vs {dim}") 1853 mode_format = _mode_format_estimator(expr.op)( 1854 index_to_dim_info[i].mode_format, d.mode_format) 1855 index_to_dim_info[i] = _DimInfo(d.dim, mode_format) 1856 1857 1858def _validate_and_collect_expr_info( 1859 expr: IndexExpr, 1860 expr_to_info: _ExprInfoDict, 1861) -> None: 1862 """Validates dimension information and constructs _ExprInfo. 1863 1864 Validates that dimensional values for the same IndexVar are the same. Collects 1865 a list of IndexVar used by the expression and their corresponding dimensional 1866 values. Constructs an _ExprInfo object to record the information for the 1867 IndexExpr. 1868 1869 This routine is passed to the post-order visitor as an _ExprVisitor object. 1870 1871 Args: 1872 expr: The IndexExpr being validated. 1873 expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the 1874 expression information. 1875 1876 Raises: 1877 ValueError if there is any problem in the IndexVars or dimensional values. 1878 """ 1879 # Objects of class Access can be shared by different expressions. Avoid 1880 # processing Access objects multiple times by skipping the processing if expr 1881 # is already in the dictionary. 1882 if expr in expr_to_info: 1883 return 1884 1885 if isinstance(expr, IndexVar): 1886 src_indices = expr, # A tuple with one element. 1887 dim_infos = _get_dummy_dim_info(), # A tuple with one element. 1888 elif isinstance(expr, Access): 1889 src_indices = expr.indices 1890 src_dims = tuple(expr.tensor.shape) 1891 if expr.tensor.format is None: 1892 # Treat each dimension of a dense tensor as DENSE for the purpose of 1893 # calculating temporary tensor storage format. 1894 mode_formats = tuple([ModeFormat.DENSE] * len(src_dims)) 1895 else: 1896 mode_formats = tuple(expr.tensor.format.format_pack.formats) 1897 assert len(src_dims) == len(mode_formats) 1898 dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)]) 1899 elif isinstance(expr, _UnaryExpr): 1900 a_info = expr_to_info[expr.a] 1901 index_to_dim_info = { 1902 i: d for i, d in zip(a_info.src_indices, a_info.dim_infos) 1903 } 1904 # Here we rely on the fact that dictionaries keep the insertion order for 1905 # keys and values. 1906 src_indices = tuple(index_to_dim_info.keys()) 1907 dim_infos = tuple(index_to_dim_info.values()) 1908 else: 1909 assert isinstance(expr, _BinaryExpr) 1910 a_info = expr_to_info[expr.a] 1911 index_to_dim_info = { 1912 i: d for i, d in zip(a_info.src_indices, a_info.dim_infos) 1913 } 1914 b_info = expr_to_info[expr.b] 1915 _validate_and_collect_dim_info(index_to_dim_info, b_info.src_indices, 1916 b_info.dim_infos, expr) 1917 # Here we rely on the fact that dictionaries keep the insertion order for 1918 # keys and values. 1919 src_indices = tuple(index_to_dim_info.keys()) 1920 dim_infos = tuple(index_to_dim_info.values()) 1921 1922 expr_to_info[expr] = _ExprInfo(src_indices, dim_infos) 1923 1924 1925def _mark_structured_op_root( 1926 expr: IndexExpr, 1927 reduce_index: IndexVar, 1928 expr_to_info: _ExprInfoDict, 1929) -> None: 1930 """Identifies the root expression for a structured op in the linalg dialect. 1931 1932 An linalg structured op can only perform reduction on the whole expression. 1933 For a TACO tensor algebra expression, the reduction on an IndexVar is done at 1934 the smallest expression that contains all the uses of the IndexVar. If such an 1935 expression is only part of the whole expression, we need to split this 1936 sub-expression tree out from its parent and implement the sub-expression as a 1937 structured op. 1938 1939 This routine identifies the root expression node for performing a reduction on 1940 the given IndexVar. If the reduction of the given IndexVar should be performed 1941 on expression X, then the IndexVar is added to expr_to_info[X].reduce_indices 1942 1943 Args: 1944 expr: The root IndexExpr for the tensor algebra expression. 1945 reduce_index: The IndexVar which we want to find out the proper expression 1946 to perform a reduction. 1947 expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. 1948 1949 Raises: 1950 ValueError: If the expression is not proper or not supported. 1951 """ 1952 expr_info = expr_to_info[expr] 1953 if isinstance(expr, Access): 1954 # Handle simple reduction expression in the format of A[i] = B[i, j]. 1955 if reduce_index in expr_info.src_indices: 1956 expr_info.reduce_indices.add(reduce_index) 1957 return 1958 elif isinstance(expr, IndexVar): 1959 # A[i] = B[i] + j is not allowed. 1960 raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.") 1961 1962 assert (isinstance(expr, _BinaryExpr)) 1963 a_info = expr_to_info[expr.a] 1964 b_info = expr_to_info[expr.b] 1965 1966 if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices: 1967 expr_info.reduce_indices.add(reduce_index) 1968 return 1969 1970 if reduce_index in a_info.src_indices: 1971 _mark_structured_op_root(expr.a, reduce_index, expr_to_info) 1972 elif reduce_index in b_info.src_indices: 1973 _mark_structured_op_root(expr.b, reduce_index, expr_to_info) 1974 else: 1975 assert False, "Unreachable path" 1976 1977 1978def _accumulate_reduce_indices( 1979 expr: IndexExpr, 1980 expr_to_info: _ExprInfoDict, 1981) -> None: 1982 """Propagates reduction indices from child expressions to parent expressions. 1983 1984 This routine is passed to the post-order visitor as an _ExprVisitor object. 1985 1986 Args: 1987 expr: The IndexExpr being visited. 1988 expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the 1989 expression information. 1990 """ 1991 assert expr in expr_to_info 1992 expr_info = expr_to_info[expr] 1993 1994 if isinstance(expr, _BinaryExpr): 1995 a_info = expr_to_info[expr.a] 1996 b_info = expr_to_info[expr.b] 1997 expr_info.acc_reduce_indices = ( 1998 a_info.acc_reduce_indices | b_info.acc_reduce_indices 1999 | expr_info.reduce_indices) 2000 elif isinstance(expr, _UnaryExpr): 2001 a_info = expr_to_info[expr.a] 2002 expr_info.acc_reduce_indices = ( 2003 a_info.acc_reduce_indices | expr_info.reduce_indices) 2004 elif isinstance(expr, IndexVar): 2005 # If an IndexVar is reducing itself, it means the IndexVar is outside the 2006 # iteration domain. This usage is now allowed and we should emit an error 2007 # before reaching here. 2008 assert not expr_info.reduce_indices 2009 else: 2010 assert isinstance(expr, Access) 2011 # Handle simple reduction expression in the format of A[i] = B[i, j]. 2012 expr_info.acc_reduce_indices = expr_info.reduce_indices 2013 2014 2015 2016def _gather_structured_op( 2017 expr: IndexExpr, 2018 expr_to_info: _ExprInfoDict, 2019 structop_roots: List[IndexExpr], 2020) -> None: 2021 """Adds structured op root expression information to structop_roots. 2022 2023 This routine is passed to the post-order visitor as an _ExprVisitor object. 2024 2025 Args: 2026 expr: The IndexExpr being visited. 2027 expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. 2028 structop_roots: The resulting list of IndexExpr that are the roots for 2029 linalg structured ops. 2030 """ 2031 if not expr_to_info[expr].reduce_indices: 2032 return 2033 2034 # If the expression is the root for reducing some indices, collect the indices 2035 # and dimensions for the reduction result. 2036 dst_indices = [] 2037 dst_dims = [] 2038 mode_fmts = [] 2039 for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos): 2040 if i not in expr_to_info[expr].acc_reduce_indices: 2041 dst_indices.append(i) 2042 dst_dims.append(d.dim) 2043 mode_fmts.append(d.mode_format) 2044 2045 # Add the information to the dictionary. 2046 op_info = _StructOpInfo( 2047 tuple(dst_indices), 2048 tuple(dst_dims), 2049 expr.dtype(), 2050 f"temp{len(structop_roots)}", 2051 _make_format(mode_fmts), 2052 ) 2053 expr_to_info[expr].structop_info = op_info 2054 2055 # Add the expression to the list of structured op roots. 2056 structop_roots.append(expr) 2057 2058 2059def _is_structured_op_leaf( 2060 expr: IndexExpr, 2061 root: IndexExpr, 2062 expr_to_info: _ExprInfoDict, 2063 *unused_args, 2064) -> bool: 2065 """Returns true iff the expression is a leaf node for a structured op. 2066 2067 The root of a structured op is a leaf of its parent structured op that uses 2068 its result. An expression node is a leaf node for the current structured op if 2069 it is an Access node or the root for a structured op that is not the current 2070 structured op. 2071 2072 This routine is passed to the post-order visitor as a _SubtreeLeafChecker 2073 object. Because the post-order visitor pass the same parameters to both 2074 _SubtreeLeafChecker and _ExprVisitor, this routine may received unused 2075 parameters. 2076 2077 Args: 2078 expr: The IndexExpr being visited. 2079 root: The root of the current structured op. 2080 expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. 2081 2082 Returns: 2083 True if the current IndexExpr is a leaf for the current structured op. 2084 """ 2085 return (expr != root and 2086 expr_to_info[expr].structop_info is not None) or isinstance( 2087 expr, Access) or isinstance(expr, IndexVar) 2088 2089 2090def _gather_structured_op_input( 2091 expr: IndexExpr, 2092 root: IndexExpr, 2093 expr_to_info: _ExprInfoDict, 2094 structop_inputs: List[IndexExpr], 2095) -> None: 2096 """Adds the IndexExpr to structop_inputs if it is an input. 2097 2098 If the current IndexExpr is an input for the current structured op, adds it to 2099 structop_inputs. The current IndexExpr is an input if it is an Access node or 2100 if it is the root for a structured op that is not the current structured op. 2101 2102 This routine is passed to the post-order visitor as an _ExprVisitor object. 2103 2104 Args: 2105 expr: The IndexExpr being visited. 2106 root: The root of the current structured op. 2107 expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. 2108 structop_inputs: The resulting list of IndexExpr that provide input to the 2109 current structured op. 2110 """ 2111 if ((expr != root or isinstance(expr, Access)) and 2112 expr not in structop_inputs) and (isinstance(expr, Access) or 2113 (expr in expr_to_info and 2114 expr_to_info[expr].structop_info)): 2115 structop_inputs.append(expr) 2116 2117 2118def _emit_structured_op_input( 2119 expr: IndexExpr, 2120 expr_to_info: _ExprInfoDict, 2121 op_def: lang.LinalgOpDef, 2122) -> lang.OperandDef: 2123 """Emits OperandDef in the linalg dialect for the input IndexExpr. 2124 2125 Args: 2126 expr: The input IndexExpr for the current structured op. 2127 expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. 2128 op_def: The linalg operation for the current structured op. 2129 2130 Returns: 2131 An OperandDef in the linalg dialect for the input IndexExpr. 2132 """ 2133 op_info = expr_to_info[expr].structop_info 2134 if op_info and not isinstance(expr, Access): 2135 # The input is a temporary tensor produced by another structured op. 2136 indices = op_info.dst_indices 2137 name = op_info.dst_name 2138 else: 2139 # The input is a user provided tensor. 2140 assert isinstance(expr, Access) 2141 indices = expr.indices 2142 name = expr.tensor.name 2143 2144 dim_sym = _mlir_symbols_from_index_vars(indices) 2145 opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym) 2146 op_def.add_operand(name, opnd) 2147 return opnd 2148 2149 2150def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr": 2151 """Build a unary operation ceil. 2152 2153 Args: 2154 a: The operand, which could be any Python object from user inputs. 2155 op: An _UnaryOp object representing the operation. 2156 2157 Returns: 2158 A _UnaryExpr object representing the operation. 2159 2160 Raises: 2161 ValueError: If a is not an IndexExpr. 2162 """ 2163 if not isinstance(a, Access): 2164 raise ValueError(f"Expected an Access Operand: {a}") 2165 return a._build_unary_expr(op) 2166 2167 2168def ceil(a: Access) -> "_UnaryExpr": 2169 """Defines the operation ceil. 2170 2171 Args: 2172 a: The operand, which could be any Python object from user inputs. 2173 2174 Returns: 2175 A _UnaryExpr object representing the operation. 2176 2177 Raises: 2178 ValueError: If a is not an IndexExpr. 2179 """ 2180 return _check_and_build_unary(a, _op_ceil) 2181 2182 2183def floor(a: Access) -> "_UnaryExpr": 2184 """Defines the operation floor. 2185 2186 Args: 2187 a: The operand, which could be any Python object from user inputs. 2188 2189 Returns: 2190 A _UnaryExpr object representing the operation. 2191 2192 Raises: 2193 ValueError: If a is not an IndexExpr. 2194 """ 2195 return _check_and_build_unary(a, _op_floor) 2196