373ce0a3创建于 4月20日历史提交
From b7c954c7594b9a39df5604dbadf6e59b6f98b1e9 Mon Sep 17 00:00:00 2001
From: xinhaiteng <xinhaiteng@xiaomi.com>
Date: Wed, 24 Apr 2024 09:18:52 +0800
Subject: [PATCH] tflite: add quantize float32 to int8 operator

VELAPLATFO-30990

Separate the 'Float32 to Int8' implementation of the 'quantize' operator to reduce the code size.

Change-Id: I83dc809d36587741f3087b25e83e454b30f9626b
Signed-off-by: xinhaiteng <xinhaiteng@xiaomi.com>
---
 tensorflow/lite/micro/kernels/quantize.cc        |  5 +++++
 tensorflow/lite/micro/kernels/quantize.h         |  4 ++++
 tensorflow/lite/micro/kernels/quantize_common.cc | 16 ++++++++++++++++
 .../lite/micro/micro_mutable_op_resolver.h       |  7 ++++---
 4 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/tensorflow/lite/micro/kernels/quantize.cc b/tensorflow/lite/micro/kernels/quantize.cc
index ba11f19b6..a6ff4f055 100644
--- a/tensorflow/lite/micro/kernels/quantize.cc
+++ b/tensorflow/lite/micro/kernels/quantize.cc
@@ -39,4 +39,9 @@ TFLMRegistration Register_QUANTIZE() {
       InitQuantizeReference, PrepareQuantizeReference, EvalQuantizeReference);
 }
 
+TFLMRegistration Register_QUANTIZE_FLOAT32_INT8() {
+  return tflite::micro::RegisterOp(
+      InitQuantizeReference, PrepareQuantizeReference, EvalQuantizeReferenceFloat32ToInt8);
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/quantize.h b/tensorflow/lite/micro/kernels/quantize.h
index ba93809a2..b97e4a2ab 100644
--- a/tensorflow/lite/micro/kernels/quantize.h
+++ b/tensorflow/lite/micro/kernels/quantize.h
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/micro/micro_common.h"
 
 namespace tflite {
 
@@ -31,7 +32,10 @@ struct OpDataQuantizeReference {
 };
 
 TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node);
+TfLiteStatus EvalQuantizeReferenceFloat32ToInt8(TfLiteContext* context, TfLiteNode* node);
 TfLiteStatus PrepareQuantizeReference(TfLiteContext* context, TfLiteNode* node);
+
+TFLMRegistration Register_QUANTIZE_FLOAT32_INT8();
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_MICRO_KERNELS_QUANTIZE_H_
diff --git a/tensorflow/lite/micro/kernels/quantize_common.cc b/tensorflow/lite/micro/kernels/quantize_common.cc
index cb04eafce..6bdd0df39 100644
--- a/tensorflow/lite/micro/kernels/quantize_common.cc
+++ b/tensorflow/lite/micro/kernels/quantize_common.cc
@@ -236,4 +236,20 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalQuantizeReferenceFloat32ToInt8(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  auto* data = static_cast<OpDataQuantizeReference*>(node->user_data);
+
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+  TFLITE_DCHECK(input->type == kTfLiteFloat32 && output->type == kTfLiteInt8);
+  reference_ops::AffineQuantize(
+      data->quantization_params, tflite::micro::GetTensorShape(input),
+      tflite::micro::GetTensorData<float>(input),
+      tflite::micro::GetTensorShape(output),
+      tflite::micro::GetTensorData<int8_t>(output));
+  return kTfLiteOk;
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h
index b3e3cbcf8..ec8979c4e 100644
--- a/tensorflow/lite/micro/micro_mutable_op_resolver.h
+++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/lite/micro/kernels/micro_ops.h"
 #include "tensorflow/lite/micro/kernels/mul.h"
 #include "tensorflow/lite/micro/kernels/pooling.h"
+#include "tensorflow/lite/micro/kernels/quantize.h"
 #include "tensorflow/lite/micro/kernels/reduce.h"
 #include "tensorflow/lite/micro/kernels/softmax.h"
 #include "tensorflow/lite/micro/kernels/transpose_conv.h"
@@ -481,9 +482,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
                       ParsePrelu);
   }
 
-  TfLiteStatus AddQuantize() {
-    return AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(),
-                      ParseQuantize);
+  TfLiteStatus AddQuantize(
+      const TFLMRegistration& registration = Register_QUANTIZE()) {
+    return AddBuiltin(BuiltinOperator_QUANTIZE, registration, ParseQuantize);
   }
 
   TfLiteStatus AddReadVariable() {
-- 
2.25.1