#include "chrome/services/speech/audio_source_fetcher_impl.h"
#include <memory>
#include <utility>
#include "base/files/file_path.h"
#include "base/memory/raw_ptr.h"
#include "base/path_service.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "chrome/services/speech/cros_speech_recognition_recognizer_impl.h"
#include "chrome/services/speech/speech_recognition_service_impl.h"
#include "media/base/audio_bus.h"
#include "media/base/audio_glitch_info.h"
#include "media/base/audio_timestamp_helper.h"
#include "media/mojo/mojom/audio_data.mojom.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "services/audio/public/cpp/fake_stream_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace speech {
namespace {

// Format of the audio that the tests below feed into the fetcher.
constexpr int kOriginalSampleRate = 48000;
constexpr int kOriginalFramesPerBuffer = 9600;

// Format the tests below expect for buffers delivered in server-based mode.
constexpr int kServerBasedRecognitionAudioSampleRate = 16000;
constexpr int kServerBasedRecognitionAudioFramesPerBuffer = 1600;

// Names of the session-length histograms checked by the tests, one per mode.
constexpr char kOnDeviceRecognitionSessionLength[] =
    "Ash.SpeechRecognitionSessionLength.OnDevice";
constexpr char kServerBasedRecognitionSessionLength[] =
    "Ash.SpeechRecognitionSessionLength.ServerBased";

}  // namespace
// Stream factory test double derived from audio::FakeStreamFactory. Its
// CreateInputStream() override stores the creation callback instead of
// running it.
class MockStreamFactory : public audio::FakeStreamFactory {
public:
MockStreamFactory() = default;
~MockStreamFactory() override = default;
// Ignores every argument except `created_callback`, which is moved into
// `last_created_callback_`. Nothing in this class runs the stored callback.
void CreateInputStream(
mojo::PendingReceiver<media::mojom::AudioInputStream> stream_receiver,
mojo::PendingRemote<media::mojom::AudioInputStreamClient> client,
mojo::PendingRemote<media::mojom::AudioInputStreamObserver> observer,
mojo::PendingRemote<media::mojom::AudioLog> log,
const std::string& device_id,
const media::AudioParameters& params,
const base::UnguessableToken& group_id,
uint32_t shared_memory_count,
bool enable_agc,
media::mojom::AudioProcessingConfigPtr processing_config,
CreateInputStreamCallback created_callback) override {
last_created_callback_ = std::move(created_callback);
}
private:
// Callback passed to the most recent CreateInputStream() call.
CreateInputStreamCallback last_created_callback_;
// Receiver constructed with this object as its implementation.
// NOTE(review): nothing in this class uses it, and the tests call
// MakeRemote()/ResetReceiver(), which this class does not define —
// presumably inherited from FakeStreamFactory; confirm whether this member
// is needed.
mojo::Receiver<media::mojom::AudioStreamFactory> receiver_{this};
};
// One-shot callback that receives the audio buffer passed to
// MockAudioSourceConsumer::AddAudio().
using OnSendAudioToSpeechRecognitionCallback =
base::OnceCallback<void(media::mojom::AudioDataS16Ptr buffer)>;
class MockAudioSourceConsumer : public AudioSourceConsumer {
public:
MockAudioSourceConsumer() = default;
~MockAudioSourceConsumer() override = default;
void AddAudio(media::mojom::AudioDataS16Ptr buffer) override {
EXPECT_FALSE(is_audio_end_);
std::move(on_send_audio_to_speech_recognition_callback_)
.Run(std::move(buffer));
}
void OnAudioCaptureEnd() override { is_audio_end_ = true; }
void OnAudioCaptureError() override {}
void SetOnSendAudioToSpeechRecognitionCallback(
OnSendAudioToSpeechRecognitionCallback callback) {
on_send_audio_to_speech_recognition_callback_ = std::move(callback);
}
private:
OnSendAudioToSpeechRecognitionCallback
on_send_audio_to_speech_recognition_callback_;
bool is_audio_end_ = false;
};
// Fixture for AudioSourceFetcherImpl tests. The bool test parameter selects
// whether the fetcher is created in server-based mode. The fetcher is given a
// MockAudioSourceConsumer so that tests can inspect the buffers it emits.
class AudioSourceFetcherImplTest
    : public testing::TestWithParam<bool>,
      public media::mojom::SpeechRecognitionRecognizerClient {
 public:
  AudioSourceFetcherImplTest() : is_server_based_(GetParam()) {}
  ~AudioSourceFetcherImplTest() override = default;

  // testing::Test:
  void SetUp() override {
    auto consumer = std::make_unique<MockAudioSourceConsumer>();
    // Keep a non-owning pointer before ownership moves into the fetcher.
    speech_recognition_recognizer_ = consumer.get();
    audio_source_fetcher_ = std::make_unique<AudioSourceFetcherImpl>(
        std::move(consumer), true, is_server_based_);
  }

 protected:
  bool is_server_based() const { return is_server_based_; }

  AudioSourceFetcherImpl* audio_source_fetcher() {
    return audio_source_fetcher_.get();
  }

  // Forwards `callback` to the mock consumer owned by the fetcher.
  void SetOnSendAudioToSpeechRecognitionCallback(
      OnSendAudioToSpeechRecognitionCallback callback) {
    speech_recognition_recognizer_->SetOnSendAudioToSpeechRecognitionCallback(
        std::move(callback));
  }

  // Blocks until the mock consumer receives one buffer and checks that the
  // buffer has the given `sample_rate` and `frame_count`. When `stop` is
  // true, Stop() is called on the fetcher after the callback is installed and
  // before waiting.
  void VerifyAudioBuffer(int sample_rate, int frame_count, bool stop = false) {
    base::RunLoop run_loop;
    SetOnSendAudioToSpeechRecognitionCallback(base::BindLambdaForTesting(
        [sample_rate, frame_count,
         &run_loop](media::mojom::AudioDataS16Ptr buffer) {
          EXPECT_EQ(sample_rate, buffer->sample_rate);
          EXPECT_EQ(frame_count, buffer->frame_count);
          run_loop.Quit();
        }));
    if (stop) {
      audio_source_fetcher()->Stop();
    }
    run_loop.Run();
  }

  // media::mojom::SpeechRecognitionRecognizerClient (all no-ops):
  void OnSpeechRecognitionRecognitionEvent(
      const media::SpeechRecognitionResult& result,
      OnSpeechRecognitionRecognitionEventCallback reply) override {}
  void OnSpeechRecognitionStopped() override {}
  void OnSpeechRecognitionError() override {}
  void OnLanguageIdentificationEvent(
      media::mojom::LanguageIdentificationEventPtr event) override {}

  base::test::TaskEnvironment task_environment_;
  std::unique_ptr<AudioSourceFetcherImpl> audio_source_fetcher_;
  base::HistogramTester histogram_tester_;

 private:
  // Non-owning pointer to the consumer owned by `audio_source_fetcher_`.
  raw_ptr<MockAudioSourceConsumer, DanglingUntriaged>
      speech_recognition_recognizer_;
  // Value of the test parameter.
  bool is_server_based_;
};
// Captures one zeroed stereo buffer at the original sample rate and checks
// the format of the buffer handed to the consumer: the server-based format in
// server-based mode, the original format otherwise. In server-based mode the
// test also waits for one more buffer after Stop(). Finally it checks that a
// session-length sample matching the captured bus duration was recorded.
TEST_P(AudioSourceFetcherImplTest, Resample) {
  MockStreamFactory stream_factory;
  media::AudioParameters input_params(
      media::AudioParameters::AUDIO_PCM_LOW_LATENCY,
      media::ChannelLayoutConfig::Stereo(), kOriginalSampleRate,
      kOriginalFramesPerBuffer);
  audio_source_fetcher()->Start(stream_factory.MakeRemote(), "device_id",
                                input_params);

  std::unique_ptr<::media::AudioBus> silent_bus =
      media::AudioBus::Create(input_params);
  silent_bus->Zero();
  audio_source_fetcher()->Capture(silent_bus.get(), base::TimeTicks::Now(), {},
                                  1.0);

  // Format expected at the consumer for the current mode.
  const bool server_based = is_server_based();
  const int expected_sample_rate = server_based
                                       ? kServerBasedRecognitionAudioSampleRate
                                       : kOriginalSampleRate;
  const int expected_frame_count =
      server_based ? kServerBasedRecognitionAudioFramesPerBuffer
                   : kOriginalFramesPerBuffer;
  VerifyAudioBuffer(expected_sample_rate, expected_frame_count);

  stream_factory.ResetReceiver();
  audio_source_fetcher()->Stop();
  if (server_based) {
    // The server-based path is expected to deliver another buffer after
    // Stop().
    VerifyAudioBuffer(expected_sample_rate, expected_frame_count);
  }

  // Destroy the fetcher before inspecting the histogram.
  audio_source_fetcher_.reset();

  const base::TimeDelta expected_length =
      media::AudioTimestampHelper::FramesToTime(silent_bus->frames(),
                                                kOriginalSampleRate);
  const char* histogram_name = server_based
                                   ? kServerBasedRecognitionSessionLength
                                   : kOnDeviceRecognitionSessionLength;
  histogram_tester_.ExpectTimeBucketCount(histogram_name, expected_length, 1);
}
// Calls Stop() on the fetcher right after a buffer has been captured, then
// destroys the fetcher. In server-based mode no buffer callback is installed
// and the test only drains pending tasks. Otherwise Stop() is issued while
// waiting for the buffer in the original format.
TEST_P(AudioSourceFetcherImplTest, StopDuringResample) {
  MockStreamFactory stream_factory;
  media::AudioParameters input_params(
      media::AudioParameters::AUDIO_PCM_LOW_LATENCY,
      media::ChannelLayoutConfig::Stereo(), kOriginalSampleRate,
      kOriginalFramesPerBuffer);
  audio_source_fetcher()->Start(stream_factory.MakeRemote(), "device_id",
                                input_params);

  std::unique_ptr<media::AudioBus> silent_bus =
      media::AudioBus::Create(input_params);
  silent_bus->Zero();
  audio_source_fetcher()->Capture(silent_bus.get(), base::TimeTicks::Now(), {},
                                  1.0);

  if (!is_server_based()) {
    // Installs the buffer callback, calls Stop(), then waits for the buffer.
    VerifyAudioBuffer(kOriginalSampleRate, kOriginalFramesPerBuffer,
                      /*stop=*/true);
  } else {
    audio_source_fetcher()->Stop();
    task_environment_.RunUntilIdle();
  }

  audio_source_fetcher_.reset();
  stream_factory.ResetReceiver();
}
// Runs every AudioSourceFetcherImplTest once per bool parameter value; the
// fixture stores the value as its server-based flag.
INSTANTIATE_TEST_SUITE_P(All, AudioSourceFetcherImplTest, ::testing::Bool());
}