# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import sparse_tensor as st

def run(f):
  """Print a banner naming `f`, call it immediately, and return it.

  Used as a decorator so that each test body executes at import time and its
  printed output is preceded by a label line carrying the function name.
  """
  print("\nTEST:", f.__name__)
  f()
  return f


# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
  """Parse a 1-D sparse encoding, cast it, print its fields, and rebuild it."""
  with Context() as ctx:
    parsed = Attribute.parse(
      '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
      'pointerBitWidth = 16, indexBitWidth = 32 }>')
    print(parsed)

    casted = st.EncodingAttr(parsed)
    # CHECK: equal: True
    print(f"equal: {casted == parsed}")

    # CHECK: dim_level_types: [<DimLevelType.compressed: 1>]
    print(f"dim_level_types: {casted.dim_level_types}")
    # CHECK: dim_ordering: None
    # Note that for 1D, the ordering is None, which exercises several special
    # cases.
    print(f"dim_ordering: {casted.dim_ordering}")
    # CHECK: pointer_bit_width: 16
    print(f"pointer_bit_width: {casted.pointer_bit_width}")
    # CHECK: index_bit_width: 32
    print(f"index_bit_width: {casted.index_bit_width}")

    created = st.EncodingAttr.get(casted.dim_level_types, None, 16, 32)
    print(created)
    # CHECK: created_equal: True
    print(f"created_equal: {created == casted}")

    # Verify that the factory creates an instance of the proper type.
    # CHECK: is_proper_instance: True
    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
    # CHECK: created_pointer_bit_width: 16
    print(f"created_pointer_bit_width: {created.pointer_bit_width}")


# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
  """Parse a 2-D sparse encoding with an explicit ordering and rebuild it."""
  with Context() as ctx:
    parsed = Attribute.parse(
      '#sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], '
      'dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, '
      'pointerBitWidth = 16, indexBitWidth = 32 }>')
    print(parsed)

    casted = st.EncodingAttr(parsed)
    # CHECK: equal: True
    print(f"equal: {casted == parsed}")

    # CHECK: dim_level_types: [<DimLevelType.dense: 0>, <DimLevelType.compressed: 1>]
    print(f"dim_level_types: {casted.dim_level_types}")
    # CHECK: dim_ordering: (d0, d1) -> (d0, d1)
    print(f"dim_ordering: {casted.dim_ordering}")
    # CHECK: pointer_bit_width: 16
    print(f"pointer_bit_width: {casted.pointer_bit_width}")
    # CHECK: index_bit_width: 32
    print(f"index_bit_width: {casted.index_bit_width}")

    created = st.EncodingAttr.get(casted.dim_level_types, casted.dim_ordering,
        16, 32)
    print(created)
    # CHECK: created_equal: True
    print(f"created_equal: {created == casted}")


# CHECK-LABEL: TEST: testEncodingAttrOnTensor
@run
def testEncodingAttrOnTensor():
  """Attach a sparse encoding to a ranked tensor type and read it back."""
  with Context() as ctx, Location.unknown():
    encoding = st.EncodingAttr(Attribute.parse(
      '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
      'pointerBitWidth = 16, indexBitWidth = 32 }>'))
    tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
    # CHECK: tensor<1024xf32, #sparse_tensor
    print(tt)
    # CHECK: #sparse_tensor.encoding
    print(tt.encoding)
    assert tt.encoding == encoding