1f13893f6SStella Laurenzo# RUN: %PYTHON %s | FileCheck %s
2f13893f6SStella Laurenzo
3f13893f6SStella Laurenzofrom mlir.ir import *
4f38633d1SStella Laurenzofrom mlir.dialects import sparse_tensor as st
5f13893f6SStella Laurenzo
def run(f):
  """Decorator that executes the test ``f`` immediately after a banner.

  The banner is a blank line followed by ``TEST: <function name>``, which is
  the text the label directives in this file anchor on. ``f`` itself is
  returned untouched, so the decorated name stays bound to the function.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  return f
10f13893f6SStella Laurenzo
11f13893f6SStella Laurenzo
12f13893f6SStella Laurenzo# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
  """Exercise EncodingAttr accessors and its factory for a rank-1 encoding.

  A single "compressed" level with pointerBitWidth 16 and indexBitWidth 32 is
  parsed from text, cast to ``st.EncodingAttr``, inspected property by
  property, and then rebuilt via ``st.EncodingAttr.get`` with no dim ordering
  so the rebuilt attribute can be compared against the parsed one.
  """
  with Context():
    encoding_asm = (
      '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
      'pointerBitWidth = 16, indexBitWidth = 32 }>')
    generic_attr = Attribute.parse(encoding_asm)
    print(generic_attr)

    typed_attr = st.EncodingAttr(generic_attr)
    # CHECK: equal: True
    print(f"equal: {typed_attr == generic_attr}")

    # CHECK: dim_level_types: [<DimLevelType.compressed: 1>]
    print(f"dim_level_types: {typed_attr.dim_level_types}")
    # The textual form above specifies no dimOrdering, so for this rank-1
    # case the accessor reports None; that exercises several special cases.
    # CHECK: dim_ordering: None
    print(f"dim_ordering: {typed_attr.dim_ordering}")
    # CHECK: pointer_bit_width: 16
    print(f"pointer_bit_width: {typed_attr.pointer_bit_width}")
    # CHECK: index_bit_width: 32
    print(f"index_bit_width: {typed_attr.index_bit_width}")

    # Rebuild from the level types plus the same widths, passing None as the
    # ordering to match what the parsed attribute reported.
    rebuilt = st.EncodingAttr.get(typed_attr.dim_level_types, None, 16, 32)
    print(rebuilt)
    # CHECK: created_equal: True
    print(f"created_equal: {rebuilt == typed_attr}")

    # The factory should hand back an st.EncodingAttr, not a generic Attribute.
    # CHECK: is_proper_instance: True
    print(f"is_proper_instance: {isinstance(rebuilt, st.EncodingAttr)}")
    # CHECK: created_pointer_bit_width: 16
    print(f"created_pointer_bit_width: {rebuilt.pointer_bit_width}")
46f13893f6SStella Laurenzo
47f13893f6SStella Laurenzo
48f13893f6SStella Laurenzo# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
  """Round-trip a rank-2 encoding that carries an explicit dim ordering.

  The attribute (dense, compressed levels; identity affine_map ordering;
  pointerBitWidth 16, indexBitWidth 32) is parsed, cast to
  ``st.EncodingAttr``, and each accessor is printed. It is then rebuilt with
  ``st.EncodingAttr.get`` using the parsed ordering. The factory result gets
  the same type and width checks as in testEncodingAttr1D, so the
  non-None-ordering path of the factory is covered as well.
  """
  # The context only needs to be active; no handle to it is used.
  with Context():
    parsed = Attribute.parse(
      '#sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], '
      'dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, '
      'pointerBitWidth = 16, indexBitWidth = 32 }>')
    print(parsed)

    casted = st.EncodingAttr(parsed)
    # CHECK: equal: True
    print(f"equal: {casted == parsed}")

    # CHECK: dim_level_types: [<DimLevelType.dense: 0>, <DimLevelType.compressed: 1>]
    print(f"dim_level_types: {casted.dim_level_types}")
    # CHECK: dim_ordering: (d0, d1) -> (d0, d1)
    print(f"dim_ordering: {casted.dim_ordering}")
    # CHECK: pointer_bit_width: 16
    print(f"pointer_bit_width: {casted.pointer_bit_width}")
    # CHECK: index_bit_width: 32
    print(f"index_bit_width: {casted.index_bit_width}")

    created = st.EncodingAttr.get(casted.dim_level_types, casted.dim_ordering,
        16, 32)
    print(created)
    # CHECK: created_equal: True
    print(f"created_equal: {created == casted}")

    # Verify that the factory also returns the proper type when a non-None
    # dim ordering is supplied, and that it keeps the requested widths.
    # CHECK: is_proper_instance: True
    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
    # CHECK: created_pointer_bit_width: 16
    print(f"created_pointer_bit_width: {created.pointer_bit_width}")
76*a2c8aebdSStella Laurenzo
77*a2c8aebdSStella Laurenzo
78*a2c8aebdSStella Laurenzo# CHECK-LABEL: TEST: testEncodingAttrOnTensor
@run
def testEncodingAttrOnTensor():
  """Attach a sparse encoding to a RankedTensorType and read it back.

  A rank-1 compressed encoding is parsed and cast to ``st.EncodingAttr``,
  used as the ``encoding`` of a 1024-element f32 ranked tensor type, and
  then recovered through the type's ``encoding`` property.
  """
  with Context(), Location.unknown():
    generic_attr = Attribute.parse(
      '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
      'pointerBitWidth = 16, indexBitWidth = 32 }>')
    sparse_encoding = st.EncodingAttr(generic_attr)
    element_type = F32Type.get()
    tensor_type = RankedTensorType.get((1024,), element_type,
                                       encoding=sparse_encoding)
    # CHECK: tensor<1024xf32, #sparse_tensor
    print(tensor_type)
    # CHECK: #sparse_tensor.encoding
    print(tensor_type.encoding)
    # The encoding read back from the type must equal the one attached.
    assert tensor_type.encoding == sparse_encoding
91