#ifndef SERVICES_ON_DEVICE_MODEL_ANDROID_BACKEND_SESSION_IMPL_ANDROID_H_
#define SERVICES_ON_DEVICE_MODEL_ANDROID_BACKEND_SESSION_IMPL_ANDROID_H_
#include <stdint.h>

#include <memory>
#include <string>
#include <vector>

#include "base/android/scoped_java_ref.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "components/optimization_guide/proto/model_execution.pb.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/on_device_model/android/sequence_checker_helper.h"
#include "services/on_device_model/backend_session.h"
namespace on_device_model {
// Android-specific implementation of the BackendSession interface.
//
// The object holds a global reference to a Java object (`java_session_`) and a
// mojo remote used to talk to a StreamingResponder (`responder_`).
// NOTE(review): presumably the Java object performs the actual inference and
// reports back through OnResponse()/OnComplete() — confirm against the .cc
// file and the Java counterpart.
class BackendSessionImplAndroid : public BackendSession {
 public:
  // Outcome of a generation request, as passed to OnComplete().
  //
  // The numeric values are spelled out explicitly and `kMaxValue` aliases the
  // last enumerator. NOTE(review): this looks like an enum shared with Java
  // and/or recorded in metrics; confirm before renumbering or reusing values.
  enum class GenerateResult {
    kSuccess = 0,
    kUnknownError = 1,
    kApiNotAvailable = 2,
    kFeatureIsNull = 3,
    kGetFeatureError = 4,
    kInferenceGeneralError = 5,
    kInferenceRequestProcessingError = 6,
    kInferenceResponseProcessingError = 7,
    kMaxValue = kInferenceResponseProcessingError,
  };

  // Creates a session for `feature` configured with `params`.
  BackendSessionImplAndroid(
      optimization_guide::proto::ModelExecutionFeature feature,
      on_device_model::mojom::SessionParamsPtr params);

  // Not copyable or assignable. The class was already implicitly non-copyable
  // because of its move-only / non-copyable members (mojo::Remote,
  // base::WeakPtrFactory); this makes the contract explicit. New sessions that
  // share state are obtained through Clone().
  BackendSessionImplAndroid(const BackendSessionImplAndroid&) = delete;
  BackendSessionImplAndroid& operator=(const BackendSessionImplAndroid&) =
      delete;

  ~BackendSessionImplAndroid() override;

  // BackendSession:
  void Append(on_device_model::mojom::AppendOptionsPtr options,
              mojo::PendingRemote<on_device_model::mojom::ContextClient> client,
              base::OnceClosure on_complete) override;
  void Generate(
      on_device_model::mojom::GenerateOptionsPtr input,
      mojo::PendingRemote<on_device_model::mojom::StreamingResponder> response,
      base::OnceClosure on_complete) override;
  void SizeInTokens(on_device_model::mojom::InputPtr input,
                    base::OnceCallback<void(uint32_t)> callback) override;
  void Score(const std::string& text,
             base::OnceCallback<void(float)> callback) override;
  void GetProbabilitiesBlocking(
      const std::string& input,
      base::OnceCallback<void(const std::vector<float>&)> callback) override;
  std::unique_ptr<BackendSession> Clone() override;
  void AsrStream(on_device_model::mojom::AsrStreamOptionsPtr options,
                 mojo::PendingRemote<on_device_model::mojom::AsrStreamResponder>
                     response) override;
  void AsrAddAudioChunk(on_device_model::mojom::AudioDataPtr data) override;

  // Public notification entry points carrying a piece of response text and the
  // final result of a generation request, respectively.
  // NOTE(review): presumably invoked from the Java/JNI side, possibly on a
  // different sequence, and forwarded to the *OnSequence() methods below —
  // confirm in the .cc file.
  void OnResponse(const std::string& response);
  void OnComplete(GenerateResult generate_result);

 private:
  // Variant of the public constructor that additionally receives a list of
  // context input pieces.
  // NOTE(review): presumably used by Clone() to seed the new session with the
  // existing context — confirm in the .cc file.
  BackendSessionImplAndroid(
      optimization_guide::proto::ModelExecutionFeature feature,
      on_device_model::mojom::SessionParamsPtr params,
      const std::vector<ml::InputPiece>& context_input_pieces);

  // Counterparts of OnResponse()/OnComplete().
  // NOTE(review): the names suggest these run on the sequence this object is
  // bound to — verify against the implementation.
  void OnResponseOnSequence(const std::string& response);
  void OnCompleteOnSequence(GenerateResult generate_result);

  // Global JNI reference to a Java object.
  // NOTE(review): presumably the Java-side session — confirm.
  base::android::ScopedJavaGlobalRef<jobject> java_session_;

  // Remote end of a StreamingResponder.
  // NOTE(review): presumably bound from the pending remote handed to
  // Generate() — confirm.
  mojo::Remote<on_device_model::mojom::StreamingResponder> responder_;

  // Input pieces forming this session's context.
  // NOTE(review): presumably filled by Append() and by the private
  // constructor — confirm.
  std::vector<ml::InputPiece> context_input_pieces_;

  // Immutable for the lifetime of the object.
  const optimization_guide::proto::ModelExecutionFeature feature_;

  on_device_model::mojom::SessionParamsPtr params_;

  SEQUENCE_CHECKER(sequence_checker_);

  // Project helper declared in sequence_checker_helper.h; see that header for
  // its exact semantics.
  SequenceCheckerHelper sequence_checker_helper_;

  // A WeakPtr to this object stored as a member.
  // NOTE(review): presumably obtained once from `weak_factory_` so that it can
  // be copied without touching the factory — confirm in the .cc file.
  base::WeakPtr<BackendSessionImplAndroid> weak_ptr_;

  // Must remain the last member: members are destroyed in reverse declaration
  // order, so the factory is torn down (and its weak pointers invalidated)
  // before any other member is destroyed.
  base::WeakPtrFactory<BackendSessionImplAndroid> weak_factory_{this};
};
}
#endif