1 //===- quant.c - Test of Quant dialect C API ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 // Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 //===----------------------------------------------------------------------===// 9 10 // RUN: mlir-capi-quant-test 2>&1 | FileCheck %s 11 12 #include "mlir-c/Dialect/Quant.h" 13 #include "mlir-c/BuiltinTypes.h" 14 #include "mlir-c/IR.h" 15 16 #include <assert.h> 17 #include <inttypes.h> 18 #include <stdio.h> 19 #include <stdlib.h> 20 21 // CHECK-LABEL: testTypeHierarchy 22 static void testTypeHierarchy(MlirContext ctx) { 23 fprintf(stderr, "testTypeHierarchy\n"); 24 25 MlirType i8 = mlirIntegerTypeGet(ctx, 8); 26 MlirType any = mlirTypeParseGet( 27 ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>")); 28 MlirType uniform = 29 mlirTypeParseGet(ctx, mlirStringRefCreateFromCString( 30 "!quant.uniform<i8<-8:7>:f32, 0.99872:127>")); 31 MlirType perAxis = mlirTypeParseGet( 32 ctx, mlirStringRefCreateFromCString( 33 "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")); 34 MlirType calibrated = mlirTypeParseGet( 35 ctx, 36 mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>")); 37 38 // The parser itself is checked in C++ dialect tests. 39 assert(!mlirTypeIsNull(any) && "couldn't parse AnyQuantizedType"); 40 assert(!mlirTypeIsNull(uniform) && "couldn't parse UniformQuantizedType"); 41 assert(!mlirTypeIsNull(perAxis) && 42 "couldn't parse UniformQuantizedPerAxisType"); 43 assert(!mlirTypeIsNull(calibrated) && 44 "couldn't parse CalibratedQuantizedType"); 45 46 // CHECK: i8 isa QuantizedType: 0 47 fprintf(stderr, "i8 isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(i8)); 48 // CHECK: any isa QuantizedType: 1 49 fprintf(stderr, "any isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(any)); 50 // CHECK: uniform isa QuantizedType: 1 51 fprintf(stderr, "uniform isa QuantizedType: %d\n", 52 mlirTypeIsAQuantizedType(uniform)); 53 // CHECK: perAxis isa QuantizedType: 1 54 fprintf(stderr, "perAxis isa QuantizedType: %d\n", 55 mlirTypeIsAQuantizedType(perAxis)); 56 // CHECK: calibrated isa QuantizedType: 1 57 fprintf(stderr, "calibrated isa QuantizedType: %d\n", 58 mlirTypeIsAQuantizedType(calibrated)); 59 60 // CHECK: any isa AnyQuantizedType: 1 61 fprintf(stderr, "any isa AnyQuantizedType: %d\n", 62 mlirTypeIsAAnyQuantizedType(any)); 63 // CHECK: uniform isa UniformQuantizedType: 1 64 fprintf(stderr, "uniform isa UniformQuantizedType: %d\n", 65 mlirTypeIsAUniformQuantizedType(uniform)); 66 // CHECK: perAxis isa UniformQuantizedPerAxisType: 1 67 fprintf(stderr, "perAxis isa UniformQuantizedPerAxisType: %d\n", 68 mlirTypeIsAUniformQuantizedPerAxisType(perAxis)); 69 // CHECK: calibrated isa CalibratedQuantizedType: 1 70 fprintf(stderr, "calibrated isa CalibratedQuantizedType: %d\n", 71 mlirTypeIsACalibratedQuantizedType(calibrated)); 72 73 // CHECK: perAxis isa UniformQuantizedType: 0 74 fprintf(stderr, "perAxis isa UniformQuantizedType: %d\n", 75 mlirTypeIsAUniformQuantizedType(perAxis)); 76 // CHECK: uniform isa CalibratedQuantizedType: 0 77 fprintf(stderr, "uniform isa CalibratedQuantizedType: %d\n", 78 mlirTypeIsACalibratedQuantizedType(uniform)); 79 fprintf(stderr, "\n"); 80 } 81 82 // CHECK-LABEL: testAnyQuantizedType 83 void testAnyQuantizedType(MlirContext ctx) { 84 fprintf(stderr, "testAnyQuantizedType\n"); 85 86 MlirType anyParsed = mlirTypeParseGet( 87 ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>")); 88 89 MlirType i8 = mlirIntegerTypeGet(ctx, 8); 90 MlirType f32 = mlirF32TypeGet(ctx); 91 MlirType any = 92 mlirAnyQuantizedTypeGet(mlirQuantizedTypeGetSignedFlag(), i8, f32, -8, 7); 93 94 // CHECK: flags: 1 95 fprintf(stderr, "flags: %u\n", mlirQuantizedTypeGetFlags(any)); 96 // CHECK: signed: 1 97 fprintf(stderr, "signed: %u\n", mlirQuantizedTypeIsSigned(any)); 98 // CHECK: storage type: i8 99 fprintf(stderr, "storage type: "); 100 mlirTypeDump(mlirQuantizedTypeGetStorageType(any)); 101 fprintf(stderr, "\n"); 102 // CHECK: expressed type: f32 103 fprintf(stderr, "expressed type: "); 104 mlirTypeDump(mlirQuantizedTypeGetExpressedType(any)); 105 fprintf(stderr, "\n"); 106 // CHECK: storage min: -8 107 fprintf(stderr, "storage min: %" PRId64 "\n", 108 mlirQuantizedTypeGetStorageTypeMin(any)); 109 // CHECK: storage max: 7 110 fprintf(stderr, "storage max: %" PRId64 "\n", 111 mlirQuantizedTypeGetStorageTypeMax(any)); 112 // CHECK: storage width: 8 113 fprintf(stderr, "storage width: %u\n", 114 mlirQuantizedTypeGetStorageTypeIntegralWidth(any)); 115 // CHECK: quantized element type: !quant.any<i8<-8:7>:f32> 116 fprintf(stderr, "quantized element type: "); 117 mlirTypeDump(mlirQuantizedTypeGetQuantizedElementType(any)); 118 fprintf(stderr, "\n"); 119 120 // CHECK: equal: 1 121 fprintf(stderr, "equal: %d\n", mlirTypeEqual(anyParsed, any)); 122 // CHECK: !quant.any<i8<-8:7>:f32> 123 mlirTypeDump(any); 124 fprintf(stderr, "\n\n"); 125 } 126 127 // CHECK-LABEL: testUniformType 128 void testUniformType(MlirContext ctx) { 129 fprintf(stderr, "testUniformType\n"); 130 131 MlirType uniformParsed = 132 mlirTypeParseGet(ctx, mlirStringRefCreateFromCString( 133 "!quant.uniform<i8<-8:7>:f32, 0.99872:127>")); 134 135 MlirType i8 = mlirIntegerTypeGet(ctx, 8); 136 MlirType f32 = mlirF32TypeGet(ctx); 137 MlirType uniform = mlirUniformQuantizedTypeGet( 138 mlirQuantizedTypeGetSignedFlag(), i8, f32, 0.99872, 127, -8, 7); 139 140 // CHECK: scale: 0.998720 141 fprintf(stderr, "scale: %lf\n", mlirUniformQuantizedTypeGetScale(uniform)); 142 // CHECK: zero point: 127 143 fprintf(stderr, "zero point: %" PRId64 "\n", 144 mlirUniformQuantizedTypeGetZeroPoint(uniform)); 145 // CHECK: fixed point: 0 146 fprintf(stderr, "fixed point: %d\n", 147 mlirUniformQuantizedTypeIsFixedPoint(uniform)); 148 149 // CHECK: equal: 1 150 fprintf(stderr, "equal: %d\n", mlirTypeEqual(uniform, uniformParsed)); 151 // CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127> 152 mlirTypeDump(uniform); 153 fprintf(stderr, "\n\n"); 154 } 155 156 // CHECK-LABEL: testUniformPerAxisType 157 void testUniformPerAxisType(MlirContext ctx) { 158 fprintf(stderr, "testUniformPerAxisType\n"); 159 160 MlirType perAxisParsed = mlirTypeParseGet( 161 ctx, mlirStringRefCreateFromCString( 162 "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")); 163 164 MlirType i8 = mlirIntegerTypeGet(ctx, 8); 165 MlirType f32 = mlirF32TypeGet(ctx); 166 double scales[] = {200.0, 0.99872}; 167 int64_t zeroPoints[] = {0, 120}; 168 MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet( 169 mlirQuantizedTypeGetSignedFlag(), i8, f32, 170 /*nDims=*/2, scales, zeroPoints, 171 /*quantizedDimension=*/1, 172 mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true, 173 /*integralWidth=*/8), 174 mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true, 175 /*integralWidth=*/8)); 176 177 // CHECK: num dims: 2 178 fprintf(stderr, "num dims: %" PRIdPTR "\n", 179 mlirUniformQuantizedPerAxisTypeGetNumDims(perAxis)); 180 // CHECK: scale 0: 200.000000 181 fprintf(stderr, "scale 0: %lf\n", 182 mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 0)); 183 // CHECK: scale 1: 0.998720 184 fprintf(stderr, "scale 1: %lf\n", 185 mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 1)); 186 // CHECK: zero point 0: 0 187 fprintf(stderr, "zero point 0: %" PRId64 "\n", 188 mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 0)); 189 // CHECK: zero point 1: 120 190 fprintf(stderr, "zero point 1: %" PRId64 "\n", 191 mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 1)); 192 // CHECK: quantized dim: 1 193 fprintf(stderr, "quantized dim: %" PRId32 "\n", 194 mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(perAxis)); 195 // CHECK: fixed point: 0 196 fprintf(stderr, "fixed point: %d\n", 197 mlirUniformQuantizedPerAxisTypeIsFixedPoint(perAxis)); 198 199 // CHECK: equal: 1 200 fprintf(stderr, "equal: %d\n", mlirTypeEqual(perAxis, perAxisParsed)); 201 // CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}> 202 mlirTypeDump(perAxis); 203 fprintf(stderr, "\n\n"); 204 } 205 206 // CHECK-LABEL: testCalibratedType 207 void testCalibratedType(MlirContext ctx) { 208 fprintf(stderr, "testCalibratedType\n"); 209 210 MlirType calibratedParsed = mlirTypeParseGet( 211 ctx, 212 mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>")); 213 214 MlirType f32 = mlirF32TypeGet(ctx); 215 MlirType calibrated = mlirCalibratedQuantizedTypeGet(f32, -0.998, 1.2321); 216 217 // CHECK: min: -0.998000 218 fprintf(stderr, "min: %lf\n", mlirCalibratedQuantizedTypeGetMin(calibrated)); 219 // CHECK: max: 1.232100 220 fprintf(stderr, "max: %lf\n", mlirCalibratedQuantizedTypeGetMax(calibrated)); 221 222 // CHECK: equal: 1 223 fprintf(stderr, "equal: %d\n", mlirTypeEqual(calibrated, calibratedParsed)); 224 // CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>> 225 mlirTypeDump(calibrated); 226 fprintf(stderr, "\n\n"); 227 } 228 229 int main() { 230 MlirContext ctx = mlirContextCreate(); 231 mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx); 232 testTypeHierarchy(ctx); 233 testAnyQuantizedType(ctx); 234 testUniformType(ctx); 235 testUniformPerAxisType(ctx); 236 testCalibratedType(ctx); 237 mlirContextDestroy(ctx); 238 return EXIT_SUCCESS; 239 } 240