1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import sparse_tensor as st
5
def run(f):
  """Run a test function as soon as it is defined.

  Intended for use as a decorator: prints a blank line followed by a
  ``TEST: <function name>`` header (the anchor the label directives in this
  file match against), calls ``f()`` with no arguments, and returns ``f``
  itself so the decorated module-level name stays bound to the function.
  """
  banner = "\nTEST: " + f.__name__
  print(banner)
  f()
  return f
10
11
# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
  """Round-trip a rank-1 sparse tensor encoding through parse, cast and get.

  Parses the encoding from its textual form, downcasts the generic attribute
  to ``st.EncodingAttr``, prints each exposed property, then rebuilds an
  equivalent attribute with the ``EncodingAttr.get`` factory and compares.
  """
  with Context():
    asm = ('#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
           'pointerBitWidth = 16, indexBitWidth = 32 }>')
    parsed = Attribute.parse(asm)
    print(parsed)

    casted = st.EncodingAttr(parsed)
    # CHECK: equal: True
    print(f"equal: {casted == parsed}")

    # The properties are printed in the order listed below. Note that for 1D
    # the ordering is None, which exercises several special cases.
    # CHECK: dim_level_types: [<DimLevelType.compressed: 1>]
    # CHECK: dim_ordering: None
    # CHECK: pointer_bit_width: 16
    # CHECK: index_bit_width: 32
    for prop in ("dim_level_types", "dim_ordering", "pointer_bit_width",
                 "index_bit_width"):
      print(f"{prop}: {getattr(casted, prop)}")

    # Rebuild the same encoding from its parts; the ordering is passed as
    # None to match the parsed attribute.
    pointer_bits, index_bits = 16, 32
    created = st.EncodingAttr.get(casted.dim_level_types, None, pointer_bits,
                                  index_bits)
    print(created)
    # CHECK: created_equal: True
    print(f"created_equal: {created == casted}")

    # The factory result must be an st.EncodingAttr instance, so its typed
    # properties are reachable directly.
    # CHECK: is_proper_instance: True
    # CHECK: created_pointer_bit_width: 16
    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
    print(f"created_pointer_bit_width: {created.pointer_bit_width}")
46
47
# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
  """Round-trip a rank-2 sparse tensor encoding with an explicit dim ordering.

  Like testEncodingAttr1D, but the encoding carries a non-None
  ``dimOrdering`` affine map, so the ordering is both read back as a
  property and passed into ``EncodingAttr.get`` when rebuilding.
  """
  with Context():
    parsed = Attribute.parse(
      '#sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], '
      'dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, '
      'pointerBitWidth = 16, indexBitWidth = 32 }>')
    print(parsed)

    casted = st.EncodingAttr(parsed)
    # CHECK: equal: True
    print(f"equal: {casted == parsed}")

    # CHECK: dim_level_types: [<DimLevelType.dense: 0>, <DimLevelType.compressed: 1>]
    print(f"dim_level_types: {casted.dim_level_types}")
    # CHECK: dim_ordering: (d0, d1) -> (d0, d1)
    print(f"dim_ordering: {casted.dim_ordering}")
    # CHECK: pointer_bit_width: 16
    print(f"pointer_bit_width: {casted.pointer_bit_width}")
    # CHECK: index_bit_width: 32
    print(f"index_bit_width: {casted.index_bit_width}")

    created = st.EncodingAttr.get(casted.dim_level_types, casted.dim_ordering,
        16, 32)
    print(created)
    # CHECK: created_equal: True
    print(f"created_equal: {created == casted}")

    # As in the 1D test, verify that the factory creates an instance of the
    # proper type. This is the only test that passes a non-None ordering, so
    # also confirm that the ordering survives construction.
    # CHECK: is_proper_instance: True
    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
    # CHECK: created_dim_ordering: (d0, d1) -> (d0, d1)
    print(f"created_dim_ordering: {created.dim_ordering}")
    # CHECK: created_pointer_bit_width: 16
    print(f"created_pointer_bit_width: {created.pointer_bit_width}")
76
77
# CHECK-LABEL: TEST: testEncodingAttrOnTensor
@run
def testEncodingAttrOnTensor():
  """Attach a sparse encoding to a RankedTensorType and read it back.

  Builds a 1024-element f32 ranked tensor type carrying the encoding, prints
  the type and its ``encoding`` property, and reports whether the property
  compares equal to the attribute that was attached.
  """
  with Context(), Location.unknown():
    encoding = st.EncodingAttr(Attribute.parse(
      '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
      'pointerBitWidth = 16, indexBitWidth = 32 }>'))
    tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
    # CHECK: tensor<1024xf32, #sparse_tensor
    print(tt)
    # CHECK: #sparse_tensor.encoding
    print(tt.encoding)
    # Report equality through the output stream, like every other check in
    # this file, instead of a bare assert (which `python -O` would strip).
    # CHECK: encoding_equal: True
    print(f"encoding_equal: {tt.encoding == encoding}")
91