# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import quant


def run(f):
    """Print a FileCheck label for `f`, call it once, and return it unchanged.

    Used as a decorator, so each test body executes as soon as its `def`
    statement runs.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: test_type_hierarchy
@run
def test_type_hierarchy():
    """Parsed quant types satisfy the base and matching subclass isinstance.

    A plain integer type is not a QuantizedType, and a type of one quant
    kind is not reported as an instance of a different kind.
    """
    with Context():
        i8 = IntegerType.get_signless(8)
        # Named `any_type` rather than `any` so the builtin is not shadowed.
        any_type = Type.parse("!quant.any<i8<-8:7>:f32>")
        uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
        per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
        calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

        # Every quant type is a QuantizedType; a builtin integer is not.
        assert not quant.QuantizedType.isinstance(i8)
        assert quant.QuantizedType.isinstance(any_type)
        assert quant.QuantizedType.isinstance(uniform)
        assert quant.QuantizedType.isinstance(per_axis)
        assert quant.QuantizedType.isinstance(calibrated)

        # Each type matches its own concrete subclass.
        assert quant.AnyQuantizedType.isinstance(any_type)
        assert quant.UniformQuantizedType.isinstance(uniform)
        assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
        assert quant.CalibratedQuantizedType.isinstance(calibrated)

        # ...and not a sibling subclass (per-axis is distinct from uniform).
        assert not quant.AnyQuantizedType.isinstance(uniform)
        assert not quant.UniformQuantizedType.isinstance(per_axis)
# CHECK-LABEL: TEST: test_any_quantized_type
@run
def test_any_quantized_type():
    """Build an AnyQuantizedType and print its properties and textual form."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        # Named `any_type` rather than `any` so the builtin is not shadowed.
        any_type = quant.AnyQuantizedType.get(
            quant.QuantizedType.FLAG_SIGNED, i8, f32, -8, 7
        )

        # CHECK: flags: 1
        print(f"flags: {any_type.flags}")
        # CHECK: signed: True
        print(f"signed: {any_type.is_signed}")
        # CHECK: storage type: i8
        print(f"storage type: {any_type.storage_type}")
        # CHECK: expressed type: f32
        print(f"expressed type: {any_type.expressed_type}")
        # CHECK: storage min: -8
        print(f"storage min: {any_type.storage_type_min}")
        # CHECK: storage max: 7
        print(f"storage max: {any_type.storage_type_max}")
        # CHECK: storage width: 8
        print(f"storage width: {any_type.storage_type_integral_width}")
        # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
        print(f"quantized element type: {any_type.quantized_element_type}")
        # CHECK: !quant.any<i8<-8:7>:f32>
        print(any_type)
        # Programmatic construction must agree with the parsed form.
        assert any_type == Type.parse("!quant.any<i8<-8:7>:f32>")


# CHECK-LABEL: TEST: test_uniform_type
@run
def test_uniform_type():
    """Build a UniformQuantizedType and print its scale/zero-point data."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        # FLAG_SIGNED is read from the QuantizedType base, as in the other tests.
        uniform = quant.UniformQuantizedType.get(
            quant.QuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7
        )

        # CHECK: scale: 0.99872
        print(f"scale: {uniform.scale}")
        # CHECK: zero point: 127
        print(f"zero point: {uniform.zero_point}")
        # CHECK: fixed point: False
        print(f"fixed point: {uniform.is_fixed_point}")
        # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
        print(uniform)
        assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")


# CHECK-LABEL: TEST: test_uniform_per_axis_type
@run
def test_uniform_per_axis_type():
    """Build a per-axis uniform type with default signed 8-bit storage bounds."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        per_axis = quant.UniformQuantizedPerAxisType.get(
            quant.QuantizedType.FLAG_SIGNED,
            i8,
            f32,
            [200, 0.99872],
            [0, 120],
            quantized_dimension=1,
            storage_type_min=quant.QuantizedType.default_minimum_for_integer(
                is_signed=True, integral_width=8
            ),
            storage_type_max=quant.QuantizedType.default_maximum_for_integer(
                is_signed=True, integral_width=8
            ),
        )

        # NOTE(review): the type is built with scales [200, 0.99872] and zero
        # points [0, 120], yet both properties are expected to print None.
        # This looks like it pins a bindings-side bug (the property not
        # returning the list it builds). Confirm, and update the expectations
        # once the bindings return the actual values.
        # CHECK: scales: None
        print(f"scales: {per_axis.scales}")
        # CHECK: zero_points: None
        print(f"zero_points: {per_axis.zero_points}")
        # CHECK: quantized dim: 1
        print(f"quantized dim: {per_axis.quantized_dimension}")
        # CHECK: fixed point: False
        print(f"fixed point: {per_axis.is_fixed_point}")
        # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
        print(per_axis)
        assert per_axis == Type.parse(
            "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"
        )


# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
    """Build a CalibratedQuantizedType and print its min/max range."""
    with Context():
        f32 = F32Type.get()
        calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)

        # CHECK: min: -0.998
        print(f"min: {calibrated.min}")
        # CHECK: max: 1.2321
        print(f"max: {calibrated.max}")
        # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
        print(calibrated)
        assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")