1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import quant
5
6
def run(f):
  """Execute `f` right away under a FileCheck "TEST:" label and return it.

  Intended for use as a decorator: the decorated function is called at
  definition time, after its name is printed so that
  `CHECK-LABEL: TEST: <name>` directives can anchor its output. The function
  object itself is returned unchanged so it stays bound at module scope.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  return f
11
12
13# CHECK-LABEL: TEST: test_type_hierarchy
@run
def test_type_hierarchy():
  """Check isinstance() across the quant dialect's type class hierarchy.

  Each parsed quant type must be a QuantizedType and an instance of its own
  concrete class, while a plain integer type and the other concrete classes
  must not match.
  """
  with Context():
    i8 = IntegerType.get_signless(8)
    # Named `any_type` rather than `any` to avoid shadowing the builtin.
    any_type = Type.parse("!quant.any<i8<-8:7>:f32>")
    uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
    per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
    calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

    # A builtin integer is not quantized; every quant type is.
    assert not quant.QuantizedType.isinstance(i8)
    assert quant.QuantizedType.isinstance(any_type)
    assert quant.QuantizedType.isinstance(uniform)
    assert quant.QuantizedType.isinstance(per_axis)
    assert quant.QuantizedType.isinstance(calibrated)

    # Each value matches its own concrete subclass.
    assert quant.AnyQuantizedType.isinstance(any_type)
    assert quant.UniformQuantizedType.isinstance(uniform)
    assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
    assert quant.CalibratedQuantizedType.isinstance(calibrated)

    # Concrete subclasses do not match one another.
    assert not quant.AnyQuantizedType.isinstance(uniform)
    assert not quant.UniformQuantizedType.isinstance(per_axis)
    assert not quant.UniformQuantizedPerAxisType.isinstance(uniform)
    assert not quant.CalibratedQuantizedType.isinstance(uniform)
36
37
38# CHECK-LABEL: TEST: test_any_quantized_type
@run
def test_any_quantized_type():
  """Build an AnyQuantizedType via get() and print each of its accessors."""
  with Context():
    i8 = IntegerType.get_signless(8)
    f32 = F32Type.get()
    # Signed i8 storage bounded to [-8, 7], expressed as f32. Named
    # `any_type` rather than `any` to avoid shadowing the builtin.
    any_type = quant.AnyQuantizedType.get(quant.QuantizedType.FLAG_SIGNED, i8,
                                          f32, -8, 7)

    # CHECK: flags: 1
    print(f"flags: {any_type.flags}")
    # CHECK: signed: True
    print(f"signed: {any_type.is_signed}")
    # CHECK: storage type: i8
    print(f"storage type: {any_type.storage_type}")
    # CHECK: expressed type: f32
    print(f"expressed type: {any_type.expressed_type}")
    # CHECK: storage min: -8
    print(f"storage min: {any_type.storage_type_min}")
    # CHECK: storage max: 7
    print(f"storage max: {any_type.storage_type_max}")
    # CHECK: storage width: 8
    print(f"storage width: {any_type.storage_type_integral_width}")
    # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
    print(f"quantized element type: {any_type.quantized_element_type}")
    # CHECK: !quant.any<i8<-8:7>:f32>
    print(any_type)
    # Building via get() must agree with parsing the textual form.
    assert any_type == Type.parse("!quant.any<i8<-8:7>:f32>")
66
67
68# CHECK-LABEL: TEST: test_uniform_type
@run
def test_uniform_type():
  """Construct a signed uniform quantized type and compare it to parse()."""
  with Context():
    storage_type = IntegerType.get_signless(8)
    expressed_type = F32Type.get()
    scale, zero_point = 0.99872, 127
    storage_min, storage_max = -8, 7
    uniform = quant.UniformQuantizedType.get(
        quant.UniformQuantizedType.FLAG_SIGNED, storage_type, expressed_type,
        scale, zero_point, storage_min, storage_max)

    # CHECK: scale: 0.99872
    print(f"scale: {uniform.scale}")
    # CHECK: zero point: 127
    print(f"zero point: {uniform.zero_point}")
    # CHECK: fixed point: False
    print(f"fixed point: {uniform.is_fixed_point}")
    # The printer renders the scale in scientific notation.
    # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
    print(uniform)
    assert uniform == Type.parse(
        "!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
86
87
88# CHECK-LABEL: TEST: test_uniform_per_axis_type
@run
def test_uniform_per_axis_type():
  """Construct a per-axis uniform quantized type and inspect its accessors."""
  with Context():
    storage_type = IntegerType.get_signless(8)
    expressed_type = F32Type.get()
    # Default storage bounds reported by the bindings for a signed i8.
    storage_min = quant.QuantizedType.default_minimum_for_integer(
        is_signed=True, integral_width=8)
    storage_max = quant.QuantizedType.default_maximum_for_integer(
        is_signed=True, integral_width=8)
    per_axis = quant.UniformQuantizedPerAxisType.get(
        quant.QuantizedType.FLAG_SIGNED,
        storage_type,
        expressed_type,
        [200, 0.99872],
        [0, 120],
        quantized_dimension=1,
        storage_type_min=storage_min,
        storage_type_max=storage_max)

    # NOTE(review): the type is built with scales [200, 0.99872] and zero
    # points [0, 120], and the printed form checked below shows both. So the
    # `None` results here look like the bindings' `scales` / `zero_points`
    # getters not returning what they compute. These CHECKs pin the current
    # behavior; presumably they should expect the real lists once the getters
    # are fixed.
    # CHECK: scales: None
    print(f"scales: {per_axis.scales}")
    # CHECK: zero_points: None
    print(f"zero_points: {per_axis.zero_points}")
    # CHECK: quantized dim: 1
    print(f"quantized dim: {per_axis.quantized_dimension}")
    # CHECK: fixed point: False
    print(f"fixed point: {per_axis.is_fixed_point}")
    # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
    print(per_axis)
    assert per_axis == Type.parse(
        "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
116
117
118# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
  """Construct a CalibratedQuantizedType; check its bounds and printed form."""
  with Context():
    expressed_type = F32Type.get()
    lower, upper = -0.998, 1.2321
    calibrated = quant.CalibratedQuantizedType.get(expressed_type, lower,
                                                   upper)

    # CHECK: min: -0.998
    print(f"min: {calibrated.min}")
    # CHECK: max: 1.2321
    print(f"max: {calibrated.max}")
    # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
    print(calibrated)
    assert calibrated == Type.parse(
        "!quant.calibrated<f32<-0.998:1.2321>>")
132