# RUN: %PYTHON %s | FileCheck %s
"""Python binding tests for the MLIR `quant` dialect types.

Each test function is executed as soon as it is defined by the `run`
decorator, which first prints a "TEST: <name>" label. The RUN line pipes the
script's stdout into FileCheck, which matches it against the directives in
the comments of this file.
"""

from mlir.ir import *
from mlir.dialects import quant


def run(f):
    """Print a "TEST: <name>" label, call ``f`` immediately, and return it.

    Used as a decorator so that every test runs when the module executes; the
    printed label gives FileCheck a per-test anchor to match against.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: test_type_hierarchy
@run
def test_type_hierarchy():
    """Check ``isinstance`` across the quantized type class hierarchy."""
    with Context():
        i8 = IntegerType.get_signless(8)
        # Named `any_type` rather than `any` to avoid shadowing the builtin.
        any_type = Type.parse("!quant.any<i8<-8:7>:f32>")
        uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
        per_axis = Type.parse(
            "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
        calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

        # A plain integer type is not quantized; every parsed quant type is.
        assert not quant.QuantizedType.isinstance(i8)
        assert quant.QuantizedType.isinstance(any_type)
        assert quant.QuantizedType.isinstance(uniform)
        assert quant.QuantizedType.isinstance(per_axis)
        assert quant.QuantizedType.isinstance(calibrated)

        # Each parsed type is recognized by its own concrete class...
        assert quant.AnyQuantizedType.isinstance(any_type)
        assert quant.UniformQuantizedType.isinstance(uniform)
        assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
        assert quant.CalibratedQuantizedType.isinstance(calibrated)

        # ...and not by a different concrete class.
        assert not quant.AnyQuantizedType.isinstance(uniform)
        assert not quant.UniformQuantizedType.isinstance(per_axis)


# CHECK-LABEL: TEST: test_any_quantized_type
@run
def test_any_quantized_type():
    """Build an ``AnyQuantizedType`` and inspect its accessors."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        # Arguments: flags, storage type, expressed type, storage min/max.
        any_type = quant.AnyQuantizedType.get(
            quant.QuantizedType.FLAG_SIGNED, i8, f32, -8, 7)

        # CHECK: flags: 1
        print(f"flags: {any_type.flags}")
        # CHECK: signed: True
        print(f"signed: {any_type.is_signed}")
        # CHECK: storage type: i8
        print(f"storage type: {any_type.storage_type}")
        # CHECK: expressed type: f32
        print(f"expressed type: {any_type.expressed_type}")
        # CHECK: storage min: -8
        print(f"storage min: {any_type.storage_type_min}")
        # CHECK: storage max: 7
        print(f"storage max: {any_type.storage_type_max}")
        # CHECK: storage width: 8
        print(f"storage width: {any_type.storage_type_integral_width}")
        # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
        print(f"quantized element type: {any_type.quantized_element_type}")
        # CHECK: !quant.any<i8<-8:7>:f32>
        print(any_type)
        # The programmatically built type must equal the parsed one.
        assert any_type == Type.parse("!quant.any<i8<-8:7>:f32>")


# CHECK-LABEL: TEST: test_uniform_type
@run
def test_uniform_type():
    """Build a ``UniformQuantizedType`` and inspect scale and zero point."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        # Arguments: flags, storage type, expressed type, scale, zero point,
        # storage min/max.
        uniform = quant.UniformQuantizedType.get(
            quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8,
            7)

        # CHECK: scale: 0.99872
        print(f"scale: {uniform.scale}")
        # CHECK: zero point: 127
        print(f"zero point: {uniform.zero_point}")
        # CHECK: fixed point: False
        print(f"fixed point: {uniform.is_fixed_point}")
        # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
        print(uniform)
        assert uniform == Type.parse(
            "!quant.uniform<i8<-8:7>:f32, 0.99872:127>")


# CHECK-LABEL: TEST: test_uniform_per_axis_type
@run
def test_uniform_per_axis_type():
    """Build a ``UniformQuantizedPerAxisType`` using the default i8 range."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        per_axis = quant.UniformQuantizedPerAxisType.get(
            quant.QuantizedType.FLAG_SIGNED,
            i8,
            f32,
            [200, 0.99872],
            [0, 120],
            quantized_dimension=1,
            storage_type_min=quant.QuantizedType.default_minimum_for_integer(
                is_signed=True, integral_width=8),
            storage_type_max=quant.QuantizedType.default_maximum_for_integer(
                is_signed=True, integral_width=8))

        # NOTE(review): the next two expectations record that the bindings
        # print None for `scales` and `zero_points` instead of the values
        # passed above ([200, 0.99872] and [0, 120]). If the bindings are
        # changed to return those lists, update both expectations -- confirm
        # against the binding implementation.
        # CHECK: scales: None
        print(f"scales: {per_axis.scales}")
        # CHECK: zero_points: None
        print(f"zero_points: {per_axis.zero_points}")
        # CHECK: quantized dim: 1
        print(f"quantized dim: {per_axis.quantized_dimension}")
        # CHECK: fixed point: False
        print(f"fixed point: {per_axis.is_fixed_point}")
        # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
        print(per_axis)
        assert per_axis == Type.parse(
            "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")


# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
    """Build a ``CalibratedQuantizedType`` and inspect its min/max range."""
    with Context():
        f32 = F32Type.get()
        calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)

        # CHECK: min: -0.998
        print(f"min: {calibrated.min}")
        # CHECK: max: 1.2321
        print(f"max: {calibrated.max}")
        # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
        print(calibrated)
        assert calibrated == Type.parse(
            "!quant.calibrated<f32<-0.998:1.2321>>")