910e62b5创建于 1月15日历史提交
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef SERVICES_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDER_EXECUTION_TASK_H_
#define SERVICES_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDER_EXECUTION_TASK_H_

#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h"

namespace passage_embeddings {

using OutputType = std::vector<float>;
using InputType = const std::vector<int>&;

class PassageEmbedderExecutionTask
    : public tflite::task::core::BaseTaskApi<OutputType, InputType> {
 public:
  explicit PassageEmbedderExecutionTask(
      std::unique_ptr<tflite::task::core::TfLiteEngine> tflite_engine);
  ~PassageEmbedderExecutionTask() override;

  std::optional<OutputType> Execute(InputType input);

 protected:
  // BaseTaskApi:
  absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                          InputType input) override;
  tflite::support::StatusOr<OutputType> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      InputType input) override;
};

}  // namespace passage_embeddings

#endif  // SERVICES_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDER_EXECUTION_TASK_H_