1 //===- Quant.cpp - C Interface for Quant dialect --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/Quant.h" 10 #include "mlir/CAPI/Registration.h" 11 #include "mlir/Dialect/Quant/QuantOps.h" 12 #include "mlir/Dialect/Quant/QuantTypes.h" 13 14 using namespace mlir; 15 16 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect) 17 18 //===---------------------------------------------------------------------===// 19 // QuantizedType 20 //===---------------------------------------------------------------------===// 21 22 bool mlirTypeIsAQuantizedType(MlirType type) { 23 return unwrap(type).isa<quant::QuantizedType>(); 24 } 25 26 unsigned mlirQuantizedTypeGetSignedFlag() { 27 return quant::QuantizationFlags::Signed; 28 } 29 30 int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned, 31 unsigned integralWidth) { 32 return quant::QuantizedType::getDefaultMinimumForInteger(isSigned, 33 integralWidth); 34 } 35 36 int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned, 37 unsigned integralWidth) { 38 return quant::QuantizedType::getDefaultMaximumForInteger(isSigned, 39 integralWidth); 40 } 41 42 MlirType mlirQuantizedTypeGetExpressedType(MlirType type) { 43 return wrap(unwrap(type).cast<quant::QuantizedType>().getExpressedType()); 44 } 45 46 unsigned mlirQuantizedTypeGetFlags(MlirType type) { 47 return unwrap(type).cast<quant::QuantizedType>().getFlags(); 48 } 49 50 bool mlirQuantizedTypeIsSigned(MlirType type) { 51 return unwrap(type).cast<quant::QuantizedType>().isSigned(); 52 } 53 54 MlirType mlirQuantizedTypeGetStorageType(MlirType type) { 55 return wrap(unwrap(type).cast<quant::QuantizedType>().getStorageType()); 56 } 57 58 int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) { 59 return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMin(); 60 } 61 62 int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) { 63 return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMax(); 64 } 65 66 unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) { 67 return unwrap(type) 68 .cast<quant::QuantizedType>() 69 .getStorageTypeIntegralWidth(); 70 } 71 72 bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, 73 MlirType candidate) { 74 return unwrap(type).cast<quant::QuantizedType>().isCompatibleExpressedType( 75 unwrap(candidate)); 76 } 77 78 MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) { 79 return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type))); 80 } 81 82 MlirType mlirQuantizedTypeCastFromStorageType(MlirType type, 83 MlirType candidate) { 84 return wrap(unwrap(type).cast<quant::QuantizedType>().castFromStorageType( 85 unwrap(candidate))); 86 } 87 88 MlirType mlirQuantizedTypeCastToStorageType(MlirType type) { 89 return wrap(quant::QuantizedType::castToStorageType( 90 unwrap(type).cast<quant::QuantizedType>())); 91 } 92 93 MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type, 94 MlirType candidate) { 95 return wrap(unwrap(type).cast<quant::QuantizedType>().castFromExpressedType( 96 unwrap(candidate))); 97 } 98 99 MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) { 100 return wrap(quant::QuantizedType::castToExpressedType(unwrap(type))); 101 } 102 103 MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type, 104 MlirType candidate) { 105 return wrap( 106 unwrap(type).cast<quant::QuantizedType>().castExpressedToStorageType( 107 unwrap(candidate))); 108 } 109 110 //===---------------------------------------------------------------------===// 111 // AnyQuantizedType 112 //===---------------------------------------------------------------------===// 113 114 bool mlirTypeIsAAnyQuantizedType(MlirType type) { 115 return unwrap(type).isa<quant::AnyQuantizedType>(); 116 } 117 118 MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType, 119 MlirType expressedType, int64_t storageTypeMin, 120 int64_t storageTypeMax) { 121 return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType), 122 unwrap(expressedType), 123 storageTypeMin, storageTypeMax)); 124 } 125 126 //===---------------------------------------------------------------------===// 127 // UniformQuantizedType 128 //===---------------------------------------------------------------------===// 129 130 bool mlirTypeIsAUniformQuantizedType(MlirType type) { 131 return unwrap(type).isa<quant::UniformQuantizedType>(); 132 } 133 134 MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType, 135 MlirType expressedType, double scale, 136 int64_t zeroPoint, int64_t storageTypeMin, 137 int64_t storageTypeMax) { 138 return wrap(quant::UniformQuantizedType::get( 139 flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint, 140 storageTypeMin, storageTypeMax)); 141 } 142 143 double mlirUniformQuantizedTypeGetScale(MlirType type) { 144 return unwrap(type).cast<quant::UniformQuantizedType>().getScale(); 145 } 146 147 int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) { 148 return unwrap(type).cast<quant::UniformQuantizedType>().getZeroPoint(); 149 } 150 151 bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) { 152 return unwrap(type).cast<quant::UniformQuantizedType>().isFixedPoint(); 153 } 154 155 //===---------------------------------------------------------------------===// 156 // UniformQuantizedPerAxisType 157 //===---------------------------------------------------------------------===// 158 159 bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) { 160 return unwrap(type).isa<quant::UniformQuantizedPerAxisType>(); 161 } 162 163 MlirType mlirUniformQuantizedPerAxisTypeGet( 164 unsigned flags, MlirType storageType, MlirType expressedType, 165 intptr_t nDims, double *scales, int64_t *zeroPoints, 166 int32_t quantizedDimension, int64_t storageTypeMin, 167 int64_t storageTypeMax) { 168 return wrap(quant::UniformQuantizedPerAxisType::get( 169 flags, unwrap(storageType), unwrap(expressedType), 170 llvm::makeArrayRef(scales, nDims), llvm::makeArrayRef(zeroPoints, nDims), 171 quantizedDimension, storageTypeMin, storageTypeMax)); 172 } 173 174 intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) { 175 return unwrap(type) 176 .cast<quant::UniformQuantizedPerAxisType>() 177 .getScales() 178 .size(); 179 } 180 181 double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) { 182 return unwrap(type) 183 .cast<quant::UniformQuantizedPerAxisType>() 184 .getScales()[pos]; 185 } 186 187 int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, 188 intptr_t pos) { 189 return unwrap(type) 190 .cast<quant::UniformQuantizedPerAxisType>() 191 .getZeroPoints()[pos]; 192 } 193 194 int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) { 195 return unwrap(type) 196 .cast<quant::UniformQuantizedPerAxisType>() 197 .getQuantizedDimension(); 198 } 199 200 bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { 201 return unwrap(type).cast<quant::UniformQuantizedPerAxisType>().isFixedPoint(); 202 } 203 204 //===---------------------------------------------------------------------===// 205 // CalibratedQuantizedType 206 //===---------------------------------------------------------------------===// 207 208 bool mlirTypeIsACalibratedQuantizedType(MlirType type) { 209 return unwrap(type).isa<quant::CalibratedQuantizedType>(); 210 } 211 212 MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, 213 double max) { 214 return wrap( 215 quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max)); 216 } 217 218 double mlirCalibratedQuantizedTypeGetMin(MlirType type) { 219 return unwrap(type).cast<quant::CalibratedQuantizedType>().getMin(); 220 } 221 222 double mlirCalibratedQuantizedTypeGetMax(MlirType type) { 223 return unwrap(type).cast<quant::CalibratedQuantizedType>().getMax(); 224 } 225