1*66d4090dSAlex Zinenko# RUN: %PYTHON %s | FileCheck %s
2*66d4090dSAlex Zinenko
3*66d4090dSAlex Zinenkofrom mlir.ir import *
4*66d4090dSAlex Zinenkofrom mlir.dialects import quant
5*66d4090dSAlex Zinenko
6*66d4090dSAlex Zinenko
def run(f):
  """Decorator that announces and immediately executes a test function.

  Emits a blank line followed by a "TEST: <function name>" banner so the
  FileCheck labels in this file have something to anchor on, invokes `f`
  with no arguments, and returns `f` unchanged so the decorated name still
  refers to the original function.
  """
  name = f.__name__
  print(f"\nTEST: {name}")
  f()
  return f
11*66d4090dSAlex Zinenko
12*66d4090dSAlex Zinenko
13*66d4090dSAlex Zinenko# CHECK-LABEL: TEST: test_type_hierarchy
@run
def test_type_hierarchy():
  """Verify the isinstance() predicates across the quantized type family."""
  with Context():
    signless_i8 = IntegerType.get_signless(8)
    any_type = Type.parse("!quant.any<i8<-8:7>:f32>")
    uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
    per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
    calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

    # A plain builtin integer type is not recognized as quantized...
    assert not quant.QuantizedType.isinstance(signless_i8)
    # ...while every parsed quant type is recognized by the base predicate.
    for quantized in (any_type, uniform, per_axis, calibrated):
      assert quant.QuantizedType.isinstance(quantized)

    # Each concrete predicate accepts the type it was parsed as.
    expected_matches = (
        (quant.AnyQuantizedType, any_type),
        (quant.UniformQuantizedType, uniform),
        (quant.UniformQuantizedPerAxisType, per_axis),
        (quant.CalibratedQuantizedType, calibrated),
    )
    for subclass, candidate in expected_matches:
      assert subclass.isinstance(candidate)

    # A concrete predicate rejects a different kind of quant type.
    assert not quant.AnyQuantizedType.isinstance(uniform)
    assert not quant.UniformQuantizedType.isinstance(per_axis)
36*66d4090dSAlex Zinenko
37*66d4090dSAlex Zinenko
38*66d4090dSAlex Zinenko# CHECK-LABEL: TEST: test_any_quantized_type
@run
def test_any_quantized_type():
  """Build an !quant.any type through the Python API and print its properties."""
  with Context():
    storage = IntegerType.get_signless(8)
    expressed = F32Type.get()
    # Signed storage with explicit bounds [-8, 7].
    any_type = quant.AnyQuantizedType.get(
        quant.QuantizedType.FLAG_SIGNED, storage, expressed, -8, 7)

    # CHECK: flags: 1
    print(f"flags: {any_type.flags}")
    # CHECK: signed: True
    print(f"signed: {any_type.is_signed}")
    # CHECK: storage type: i8
    print(f"storage type: {any_type.storage_type}")
    # CHECK: expressed type: f32
    print(f"expressed type: {any_type.expressed_type}")
    # CHECK: storage min: -8
    print(f"storage min: {any_type.storage_type_min}")
    # CHECK: storage max: 7
    print(f"storage max: {any_type.storage_type_max}")
    # CHECK: storage width: 8
    print(f"storage width: {any_type.storage_type_integral_width}")
    # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
    print(f"quantized element type: {any_type.quantized_element_type}")
    # CHECK: !quant.any<i8<-8:7>:f32>
    print(any_type)
    # The API-built type must be identical to the parsed textual form.
    expected = Type.parse("!quant.any<i8<-8:7>:f32>")
    assert any_type == expected
66*66d4090dSAlex Zinenko
67*66d4090dSAlex Zinenko
68*66d4090dSAlex Zinenko# CHECK-LABEL: TEST: test_uniform_type
@run
def test_uniform_type():
  """Build a !quant.uniform type through the Python API and print its properties.

  The type is signed i8 storage with bounds [-8, 7], f32 expressed type,
  scale 0.99872 and zero point 127. It must compare equal to the parsed
  textual form of the same type.
  """
  with Context():
    i8 = IntegerType.get_signless(8)
    f32 = F32Type.get()
    # Take FLAG_SIGNED from the QuantizedType base class, matching the other
    # constructor calls in this file.
    uniform = quant.UniformQuantizedType.get(
        quant.QuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7)

    # CHECK: scale: 0.99872
    print(f"scale: {uniform.scale}")
    # CHECK: zero point: 127
    print(f"zero point: {uniform.zero_point}")
    # CHECK: fixed point: False
    print(f"fixed point: {uniform.is_fixed_point}")
    # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
    print(uniform)
    assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
86*66d4090dSAlex Zinenko
87*66d4090dSAlex Zinenko
88*66d4090dSAlex Zinenko# CHECK-LABEL: TEST: test_uniform_per_axis_type
@run
def test_uniform_per_axis_type():
  """Build a per-axis !quant.uniform type through the Python API and print it."""
  with Context():
    storage = IntegerType.get_signless(8)
    expressed = F32Type.get()
    # Default storage bounds for a signed 8-bit integer, from the
    # QuantizedType helpers.
    lower = quant.QuantizedType.default_minimum_for_integer(
        is_signed=True, integral_width=8)
    upper = quant.QuantizedType.default_maximum_for_integer(
        is_signed=True, integral_width=8)
    axis_type = quant.UniformQuantizedPerAxisType.get(
        quant.QuantizedType.FLAG_SIGNED,
        storage,
        expressed,
        [200, 0.99872],
        [0, 120],
        quantized_dimension=1,
        storage_type_min=lower,
        storage_type_max=upper)

    # NOTE(review): the expectations below pin `scales` and `zero_points` to
    # None even though the type was built from [200, 0.99872] and [0, 120].
    # This looks like it records a bindings defect (the properties not
    # returning their computed values) rather than intended behavior --
    # confirm against the Python bindings and update once fixed.
    # CHECK: scales: None
    print(f"scales: {axis_type.scales}")
    # CHECK: zero_points: None
    print(f"zero_points: {axis_type.zero_points}")
    # CHECK: quantized dim: 1
    print(f"quantized dim: {axis_type.quantized_dimension}")
    # CHECK: fixed point: False
    print(f"fixed point: {axis_type.is_fixed_point}")
    # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
    print(axis_type)
    expected = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
    assert axis_type == expected
116*66d4090dSAlex Zinenko
117*66d4090dSAlex Zinenko
118*66d4090dSAlex Zinenko# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
  """Build a !quant.calibrated type through the Python API and print its range."""
  with Context():
    calibrated = quant.CalibratedQuantizedType.get(F32Type.get(), -0.998,
                                                   1.2321)

    # CHECK: min: -0.998
    print(f"min: {calibrated.min}")
    # CHECK: max: 1.2321
    print(f"max: {calibrated.max}")
    # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
    print(calibrated)
    # The API-built type must be identical to the parsed textual form.
    textual_form = "!quant.calibrated<f32<-0.998:1.2321>>"
    assert calibrated == Type.parse(textual_form)
132