#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"
#include <assert.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
static void testTypeHierarchy(MlirContext ctx) {
fprintf(stderr, "testTypeHierarchy\n");
MlirType i8 = mlirIntegerTypeGet(ctx, 8);
MlirType any = mlirTypeParseGet(
ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
MlirType uniform =
mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
"!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
MlirType perAxis = mlirTypeParseGet(
ctx, mlirStringRefCreateFromCString(
"!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
MlirType calibrated = mlirTypeParseGet(
ctx,
mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));
assert(!mlirTypeIsNull(any) && "couldn't parse AnyQuantizedType");
assert(!mlirTypeIsNull(uniform) && "couldn't parse UniformQuantizedType");
assert(!mlirTypeIsNull(perAxis) &&
"couldn't parse UniformQuantizedPerAxisType");
assert(!mlirTypeIsNull(calibrated) &&
"couldn't parse CalibratedQuantizedType");
fprintf(stderr, "i8 isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(i8));
fprintf(stderr, "any isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(any));
fprintf(stderr, "uniform isa QuantizedType: %d\n",
mlirTypeIsAQuantizedType(uniform));
fprintf(stderr, "perAxis isa QuantizedType: %d\n",
mlirTypeIsAQuantizedType(perAxis));
fprintf(stderr, "calibrated isa QuantizedType: %d\n",
mlirTypeIsAQuantizedType(calibrated));
fprintf(stderr, "any isa AnyQuantizedType: %d\n",
mlirTypeIsAAnyQuantizedType(any));
fprintf(stderr, "uniform isa UniformQuantizedType: %d\n",
mlirTypeIsAUniformQuantizedType(uniform));
fprintf(stderr, "perAxis isa UniformQuantizedPerAxisType: %d\n",
mlirTypeIsAUniformQuantizedPerAxisType(perAxis));
fprintf(stderr, "calibrated isa CalibratedQuantizedType: %d\n",
mlirTypeIsACalibratedQuantizedType(calibrated));
fprintf(stderr, "perAxis isa UniformQuantizedType: %d\n",
mlirTypeIsAUniformQuantizedType(perAxis));
fprintf(stderr, "uniform isa CalibratedQuantizedType: %d\n",
mlirTypeIsACalibratedQuantizedType(uniform));
fprintf(stderr, "\n");
}
/// Builds `!quant.any<i8<-8:7>:f32>` through the C API, prints every generic
/// QuantizedType accessor on it, and compares it with the parsed form.
static void testAnyQuantizedType(MlirContext ctx) {
  fprintf(stderr, "testAnyQuantizedType\n");
  MlirType anyParsed = mlirTypeParseGet(
      ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
  // Fail loudly on a parse error rather than silently printing "equal: 0".
  assert(!mlirTypeIsNull(anyParsed) && "couldn't parse AnyQuantizedType");

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);
  // Signed i8 storage with storage min/max -8/7, expressed as f32; these
  // arguments mirror the `i8<-8:7>:f32` textual form parsed above.
  MlirType any =
      mlirAnyQuantizedTypeGet(mlirQuantizedTypeGetSignedFlag(), i8, f32, -8, 7);

  fprintf(stderr, "flags: %u\n", mlirQuantizedTypeGetFlags(any));
  // mlirQuantizedTypeIsSigned returns bool, which promotes to int, so %d is
  // the matching conversion (the original %u mismatched signedness).
  fprintf(stderr, "signed: %d\n", mlirQuantizedTypeIsSigned(any));
  fprintf(stderr, "storage type: ");
  mlirTypeDump(mlirQuantizedTypeGetStorageType(any));
  fprintf(stderr, "\n");
  fprintf(stderr, "expressed type: ");
  mlirTypeDump(mlirQuantizedTypeGetExpressedType(any));
  fprintf(stderr, "\n");
  fprintf(stderr, "storage min: %" PRId64 "\n",
          mlirQuantizedTypeGetStorageTypeMin(any));
  fprintf(stderr, "storage max: %" PRId64 "\n",
          mlirQuantizedTypeGetStorageTypeMax(any));
  fprintf(stderr, "storage width: %u\n",
          mlirQuantizedTypeGetStorageTypeIntegralWidth(any));
  fprintf(stderr, "quantized element type: ");
  mlirTypeDump(mlirQuantizedTypeGetQuantizedElementType(any));
  fprintf(stderr, "\n");
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(anyParsed, any));
  mlirTypeDump(any);
  fprintf(stderr, "\n\n");
}
/// Builds `!quant.uniform<i8<-8:7>:f32, 0.99872:127>` through the C API,
/// prints its scale, zero point and fixed-point flag, and compares it with the
/// parsed form.
static void testUniformType(MlirContext ctx) {
  fprintf(stderr, "testUniformType\n");
  MlirType uniformParsed =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
                                "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
  // Fail loudly on a parse error rather than silently printing "equal: 0".
  assert(!mlirTypeIsNull(uniformParsed) &&
         "couldn't parse UniformQuantizedType");

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);
  // Scale 0.99872, zero point 127, storage min/max -8/7; these mirror the
  // textual form parsed above.
  MlirType uniform = mlirUniformQuantizedTypeGet(
      mlirQuantizedTypeGetSignedFlag(), i8, f32, 0.99872, 127, -8, 7);

  fprintf(stderr, "scale: %lf\n", mlirUniformQuantizedTypeGetScale(uniform));
  fprintf(stderr, "zero point: %" PRId64 "\n",
          mlirUniformQuantizedTypeGetZeroPoint(uniform));
  fprintf(stderr, "fixed point: %d\n",
          mlirUniformQuantizedTypeIsFixedPoint(uniform));
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(uniform, uniformParsed));
  mlirTypeDump(uniform);
  fprintf(stderr, "\n\n");
}
/// Builds `!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>` through the C API,
/// prints its per-axis scales, zero points and quantized dimension, and
/// compares it with the parsed form.
static void testUniformPerAxisType(MlirContext ctx) {
  fprintf(stderr, "testUniformPerAxisType\n");
  MlirType perAxisParsed = mlirTypeParseGet(
      ctx, mlirStringRefCreateFromCString(
               "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
  // Fail loudly on a parse error rather than silently printing "equal: 0".
  assert(!mlirTypeIsNull(perAxisParsed) &&
         "couldn't parse UniformQuantizedPerAxisType");

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);
  // One (scale, zero point) pair per entry. The textual form above writes no
  // zero point for `2.0e+2`; it is paired with 0 here.
  double scales[] = {200.0, 0.99872};
  int64_t zeroPoints[] = {0, 120};
  // Derive the count from the arrays so it cannot drift from their contents
  // (previously a hard-coded 2).
  const intptr_t numScales = (intptr_t)(sizeof scales / sizeof scales[0]);
  assert(numScales == (intptr_t)(sizeof zeroPoints / sizeof zeroPoints[0]) &&
         "scales and zeroPoints must have the same length");

  // Quantized along dimension 1. Storage bounds come from the default signed
  // 8-bit range, since the textual form gives no explicit <min:max>.
  MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet(
      mlirQuantizedTypeGetSignedFlag(), i8, f32, numScales, scales, zeroPoints,
      /*quantizedDimension=*/1,
      mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
                                                   /*integralWidth=*/8),
      mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
                                                   /*integralWidth=*/8));

  fprintf(stderr, "num dims: %" PRIdPTR "\n",
          mlirUniformQuantizedPerAxisTypeGetNumDims(perAxis));
  // Read every entry back through the type accessors. All scales are printed
  // before all zero points to keep the established output order.
  for (intptr_t i = 0; i < numScales; ++i) {
    fprintf(stderr, "scale %" PRIdPTR ": %lf\n", i,
            mlirUniformQuantizedPerAxisTypeGetScale(perAxis, i));
  }
  for (intptr_t i = 0; i < numScales; ++i) {
    fprintf(stderr, "zero point %" PRIdPTR ": %" PRId64 "\n", i,
            mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, i));
  }
  fprintf(stderr, "quantized dim: %" PRId32 "\n",
          mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(perAxis));
  fprintf(stderr, "fixed point: %d\n",
          mlirUniformQuantizedPerAxisTypeIsFixedPoint(perAxis));
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(perAxis, perAxisParsed));
  mlirTypeDump(perAxis);
  fprintf(stderr, "\n\n");
}
/// Builds `!quant.calibrated<f32<-0.998:1.2321>>` through the C API, prints
/// its min/max, and compares it with the parsed form.
static void testCalibratedType(MlirContext ctx) {
  fprintf(stderr, "testCalibratedType\n");
  MlirType calibratedParsed = mlirTypeParseGet(
      ctx,
      mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));
  // Fail loudly on a parse error rather than silently printing "equal: 0".
  assert(!mlirTypeIsNull(calibratedParsed) &&
         "couldn't parse CalibratedQuantizedType");

  MlirType f32 = mlirF32TypeGet(ctx);
  // Min -0.998 and max 1.2321 over f32, mirroring the textual form above.
  MlirType calibrated = mlirCalibratedQuantizedTypeGet(f32, -0.998, 1.2321);

  fprintf(stderr, "min: %lf\n", mlirCalibratedQuantizedTypeGetMin(calibrated));
  fprintf(stderr, "max: %lf\n", mlirCalibratedQuantizedTypeGetMax(calibrated));
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(calibrated, calibratedParsed));
  mlirTypeDump(calibrated);
  fprintf(stderr, "\n\n");
}
int main(void) {
MlirContext ctx = mlirContextCreate();
mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx);
testTypeHierarchy(ctx);
testAnyQuantizedType(ctx);
testUniformType(ctx);
testUniformPerAxisType(ctx);
testCalibratedType(ctx);
mlirContextDestroy(ctx);
return EXIT_SUCCESS;
}