#include "media/webrtc/audio_processor.h"
#include <inttypes.h>
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <array>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
#include "base/containers/heap_array.h"
#include "base/feature_list.h"
#include "base/functional/callback_helpers.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/strings/stringprintf.h"
#include "base/strings/to_string.h"
#include "base/task/thread_pool.h"
#include "base/trace_event/trace_event.h"
#include "build/build_config.h"
#include "build/chromecast_buildflags.h"
#include "media/base/audio_bus.h"
#include "media/base/audio_fifo.h"
#include "media/base/audio_parameters.h"
#include "media/base/audio_timestamp_helper.h"
#include "media/base/channel_layout.h"
#include "media/base/limits.h"
#include "media/base/media_switches.h"
#include "media/webrtc/constants.h"
#include "media/webrtc/helpers.h"
#include "media/webrtc/webrtc_features.h"
#include "third_party/tflite/src/tensorflow/lite/model_builder.h"
#include "third_party/webrtc/modules/audio_processing/include/audio_processing.h"
#include "third_party/webrtc_overrides/task_queue_factory.h"
namespace media {
namespace {
// Passed to audio_delay_stats_reporter_ as its buffers-per-second rate.
// 100 presumably corresponds to 10 ms capture buffers — TODO confirm.
constexpr int kBuffersPerSecond = 100;

// Returns the capture buffer size, in frames, to request for `device_format`.
//
// - Android (non-Cast) and OHOS: always 2 * sample_rate / 100 frames (20 ms);
//   `need_webrtc_processing` is not consulted on these platforms.
// - Elsewhere, when WebRTC processing is needed: sample_rate / 100 frames,
//   which must equal webrtc::AudioProcessing::GetFrameSize().
// - Elsewhere, without processing: the device's frames_per_buffer() when it is
//   non-zero, otherwise sample_rate / 100 frames.
//
// `device_format` is only read, so it is taken by const reference instead of
// by (const) value to avoid copying AudioParameters on every call.
int GetCaptureBufferSize([[maybe_unused]] bool need_webrtc_processing,
                         const AudioParameters& device_format) {
#if (BUILDFLAG(IS_ANDROID) && !BUILDFLAG(IS_CAST_ANDROID)) || BUILDFLAG(IS_OHOS)
  return 2 * device_format.sample_rate() / 100;
#else
  const int buffer_size_10_ms = device_format.sample_rate() / 100;
  if (need_webrtc_processing) {
    // The APM consumes blocks of exactly GetFrameSize() frames.
    DCHECK_EQ(buffer_size_10_ms, webrtc::AudioProcessing::GetFrameSize(
                                     device_format.sample_rate()));
    return buffer_size_10_ms;
  }
  // A zero frames_per_buffer() means the device did not report a size.
  if (const int hardware_buffer_size = device_format.frames_per_buffer()) {
    return hardware_buffer_size;
  }
  return buffer_size_10_ms;
#endif
}
// Returns whether `apm` needs the playout (render) signal as a reference.
// That is the case when echo cancellation is enabled, or when
// gain_controller1 is enabled without its analog gain controller.
// A null `apm` never needs a reference.
bool ApmNeedsPlayoutReference(const webrtc::AudioProcessing* apm) {
  if (apm == nullptr) {
    return false;
  }
  const webrtc::AudioProcessing::Config apm_config = apm->GetConfig();
  if (apm_config.echo_canceller.enabled) {
    return true;
  }
  const auto& agc1 = apm_config.gain_controller1;
  return agc1.enabled && !agc1.analog_gain_controller.enabled;
}
}
// Owns a zero-initialized media::AudioBus together with a table of its
// per-channel data pointers. The table can be handed to APIs that take
// `float* const*` channel arrays.
class AudioProcessorCaptureBus {
 public:
  AudioProcessorCaptureBus(int channels, int frames)
      : bus_(media::AudioBus::Create(channels, frames)),
        channel_ptrs_(
            base::HeapArray<float*>::WithSize(static_cast<size_t>(channels))) {
    // Start from silence.
    bus_->Zero();
  }

  // Returns the owned bus.
  media::AudioBus* bus() { return bus_.get(); }

  // Re-reads the data pointer of every channel of bus() into the pointer
  // table and returns a view of that table.
  base::span<float* const> channel_ptrs() {
    const int num_channels = bus_->channels();
    for (int ch = 0; ch < num_channels; ++ch) {
      channel_ptrs_[ch] = bus_->channel_span(ch).data();
    }
    return channel_ptrs_;
  }

 private:
  std::unique_ptr<media::AudioBus> bus_;
  // One slot per channel; refreshed by channel_ptrs().
  base::HeapArray<float*> channel_ptrs_;
};
// Rebuffers captured audio from a fixed source block size to a fixed
// destination block size, and tracks the audio delay to report with each
// destination block. It is intended to drop trailing channels when the source
// has more channels than the destination (see the constructor).
// Usage contract enforced by the CHECKs below: each Push() must be followed by
// Consume() calls until Consume() returns false before the next Push().
class AudioProcessorCaptureFifo {
public:
// source_channels / source_frames: shape of every bus passed to Push(); only
// verified when DCHECKs are on.
// destination_channels / destination_frames: shape of the bus returned by
// Consume(); requires source_channels >= destination_channels.
// sample_rate: used to convert buffered frame counts into time.
AudioProcessorCaptureFifo(int source_channels,
int destination_channels,
int source_frames,
int destination_frames,
int sample_rate)
:
#if DCHECK_IS_ON()
source_channels_(source_channels),
source_frames_(source_frames),
#endif
sample_rate_(sample_rate),
destination_(
std::make_unique<AudioProcessorCaptureBus>(destination_channels,
destination_frames)),
data_available_(false) {
DCHECK_GE(source_channels, destination_channels);
// Wrapper bus with only `destination_channels` channels. Push() points it at
// the source's channel data, with the intent of dropping the extra channels.
if (source_channels > destination_channels) {
audio_source_intermediate_ =
media::AudioBus::CreateWrapper(destination_channels);
}
// A FIFO is only needed when the block sizes differ. Otherwise Push() copies
// directly into the destination bus.
if (source_frames != destination_frames) {
// Push() CHECKs that fewer than destination_frames are buffered before each
// push, so the FIFO never holds more than
// destination_frames + source_frames <= 2 * max(...) frames.
const int fifo_frames = 2 * std::max(source_frames, destination_frames);
fifo_ =
std::make_unique<media::AudioFifo>(destination_channels, fifo_frames);
}
}
// Adds one source block. With a FIFO, the delay recorded for the next
// destination block is `audio_delay` plus the duration of the frames already
// buffered ahead of this block. Without a FIFO it is `audio_delay` itself.
void Push(const media::AudioBus& source, base::TimeDelta audio_delay) {
#if DCHECK_IS_ON()
DCHECK_EQ(source.channels(), source_channels_);
DCHECK_EQ(source.frames(), source_frames_);
#endif
const media::AudioBus* source_to_push = &source;
if (audio_source_intermediate_) {
// NOTE(review): source.AllChannels() holds source.channels() entries, while
// the wrapper was created with destination_channels. If SetAllChannels()
// requires matching sizes or replaces the channel list wholesale, this
// channel-dropping path would misbehave. In this file the FIFO is only
// constructed with equal channel counts (AudioProcessor constructor), so the
// path appears unused. TODO confirm against AudioBus::SetAllChannels().
audio_source_intermediate_->set_frames(source.frames());
audio_source_intermediate_->SetAllChannels(source.AllChannels());
source_to_push = audio_source_intermediate_.get();
}
if (fifo_) {
// The previous Push() must have been drained down to less than one
// destination block.
CHECK_LT(fifo_->frames(),
static_cast<size_t>(destination_->bus()->frames()));
// Computed before pushing, so it accounts only for frames that will be
// consumed ahead of this block's data.
next_audio_delay_ =
audio_delay + fifo_->frames() * base::Seconds(1) / sample_rate_;
fifo_->Push(source_to_push);
} else {
// Same block size: the previous block must already have been consumed.
CHECK(!data_available_);
source_to_push->CopyTo(destination_->bus());
next_audio_delay_ = audio_delay;
data_available_ = true;
}
}
// If a full destination block is available, sets *destination to the
// internally owned destination bus and *audio_delay to that block's delay,
// and returns true. Otherwise it returns false and leaves the out-params
// untouched. The same bus is reused, so its contents are overwritten by later
// Push()/Consume() calls.
bool Consume(AudioProcessorCaptureBus** destination,
base::TimeDelta* audio_delay) {
if (fifo_) {
if (fifo_->frames() <
static_cast<size_t>(destination_->bus()->frames())) {
return false;
}
fifo_->Consume(destination_->bus(), 0, destination_->bus()->frames());
*audio_delay = next_audio_delay_;
// The following block starts one destination block later, so its delay is
// shorter by that block's duration.
next_audio_delay_ -=
destination_->bus()->frames() * base::Seconds(1) / sample_rate_;
} else {
if (!data_available_) {
return false;
}
*audio_delay = next_audio_delay_;
// Data was already copied into destination_ by Push().
data_available_ = false;
}
*destination = destination_.get();
return true;
}
private:
#if DCHECK_IS_ON()
const int source_channels_;
const int source_frames_;
#endif
const int sample_rate_;
// Non-null only when source_channels > destination_channels.
std::unique_ptr<media::AudioBus> audio_source_intermediate_;
// Bus handed out by Consume(); reused for every block.
std::unique_ptr<AudioProcessorCaptureBus> destination_;
// Non-null only when source_frames != destination_frames.
std::unique_ptr<media::AudioFifo> fifo_;
// Delay to report with the next block returned by Consume().
base::TimeDelta next_audio_delay_;
// FIFO-less mode only: true between a Push() and the matching Consume().
bool data_available_;
};
// Creates an AudioProcessor for the given formats.
// CreateWebRtcAudioProcessingModule() builds the WebRTC APM and the extra AEC
// delay from `settings` and `neural_residual_echo_estimator_model`. The
// returned APM may be null; ApmNeedsPlayoutReference() returns false in that
// case.
std::unique_ptr<AudioProcessor> AudioProcessor::Create(
    DeliverProcessedAudioCallback deliver_processed_audio_callback,
    LogCallback log_callback,
    const AudioProcessingSettings& settings,
    const media::AudioParameters& input_format,
    const media::AudioParameters& output_format,
    raw_ptr<const tflite::FlatBufferModel>
        neural_residual_echo_estimator_model) {
  log_callback.Run(base::StringPrintf(
      "AudioProcessor::Create({multi_channel_capture_processing=%s})",
      base::ToString(settings.multi_channel_capture_processing)));
  auto [webrtc_audio_processing, added_aec_delay] =
      media::CreateWebRtcAudioProcessingModule(
          settings, neural_residual_echo_estimator_model);
  // Inspect the APM in its own statement, before ownership is transferred.
  // Reading `webrtc_audio_processing` in the same argument list that moves
  // from it was only correct because std::make_unique forwards by reference.
  // It breaks if the construction is ever rewritten (e.g. to a direct `new`),
  // and it trips bugprone-use-after-move.
  const bool needs_playout_reference =
      ApmNeedsPlayoutReference(webrtc_audio_processing.get());
  return std::make_unique<AudioProcessor>(
      std::move(deliver_processed_audio_callback), std::move(log_callback),
      input_format, output_format, std::move(webrtc_audio_processing),
      needs_playout_reference, added_aec_delay);
}
// Sets up capture rebuffering and, when `webrtc_audio_processing` is non-null,
// the output bus that APM writes into.
// `added_aec_delay` is subtracted from the capture time of every delivered
// block (see ProcessCapturedAudio()).
AudioProcessor::AudioProcessor(
DeliverProcessedAudioCallback deliver_processed_audio_callback,
LogCallback log_callback,
const media::AudioParameters& input_format,
const media::AudioParameters& output_format,
webrtc::scoped_refptr<webrtc::AudioProcessing> webrtc_audio_processing,
bool needs_playout_reference,
base::TimeDelta added_aec_delay)
: webrtc_audio_processing_(webrtc_audio_processing),
needs_playout_reference_(needs_playout_reference),
log_callback_(std::move(log_callback)),
added_aec_delay_(added_aec_delay),
input_format_(input_format),
output_format_(output_format),
deliver_processed_audio_callback_(
std::move(deliver_processed_audio_callback)),
audio_delay_stats_reporter_(kBuffersPerSecond),
// Unretained: playout_fifo_ is a member and cannot outlive `this`. The FIFO
// presumably invokes the callback synchronously from Push(). TODO confirm.
playout_fifo_(
base::BindRepeating(&AudioProcessor::AnalyzePlayoutData,
base::Unretained(this))) {
DCHECK(deliver_processed_audio_callback_);
DCHECK(log_callback_);
CHECK(input_format_.IsValid());
CHECK(output_format_.IsValid());
// With APM, output buffers must be exactly one APM block.
if (webrtc_audio_processing_) {
DCHECK_EQ(
webrtc::AudioProcessing::GetFrameSize(output_format_.sample_rate()),
output_format_.frames_per_buffer());
}
if (input_format_.sample_rate() % 100 != 0 ||
output_format_.sample_rate() % 100 != 0) {
SendLogMessage(base::StringPrintf(
"%s: WARNING: Sample rate not divisible by 100, processing is provided "
"on a best-effort basis. input rate=[%d], output rate=[%d]",
__func__, input_format_.sample_rate(), output_format_.sample_rate()));
}
// NOTE(review): `%d` receives the int64_t from InMilliseconds(). This is only
// well-defined if base::StringPrintf is type-safe (absl::StrFormat-based), as
// the `%s` + base::ToString() usage elsewhere in this file implies. TODO
// confirm.
SendLogMessage(base::StringPrintf(
"%s({input_format_=[%s], output_format_=[%s], added_aec_delay=[%d]})",
__func__, input_format_.AsHumanReadableString().c_str(),
output_format_.AsHumanReadableString().c_str(),
added_aec_delay_.InMilliseconds()));
// With APM, rebuffer capture audio into APM-sized blocks. Otherwise rebuffer
// straight to the output buffer size.
const int fifo_output_frames_per_buffer =
webrtc_audio_processing_
? webrtc::AudioProcessing::GetFrameSize(input_format_.sample_rate())
: output_format_.frames_per_buffer();
SendLogMessage(base::StringPrintf(
"%s => (capture FIFO: fifo_output_frames_per_buffer=%d)", __func__,
fifo_output_frames_per_buffer));
// Source and destination channel counts are identical here (input_format_ is
// a copy of input_format), so the FIFO never drops channels.
capture_fifo_ = std::make_unique<AudioProcessorCaptureFifo>(
input_format.channels(), input_format_.channels(),
input_format.frames_per_buffer(), fifo_output_frames_per_buffer,
input_format.sample_rate());
// Destination for ProcessData(). Without APM, FIFO blocks are delivered
// directly.
if (webrtc_audio_processing_) {
output_bus_ = std::make_unique<AudioProcessorCaptureBus>(
output_format_.channels(), output_format.frames_per_buffer());
}
}
// Stops any AEC dump in progress and releases its worker queue (OnStopDump())
// before the members are destroyed.
AudioProcessor::~AudioProcessor() {
DCHECK_CALLED_ON_VALID_SEQUENCE(owning_sequence_);
OnStopDump();
}
// Entry point for captured audio. Pushes `audio_source` into capture_fifo_.
// For every complete block the FIFO yields, the block is either run through
// the WebRTC APM (when present) or passed through unchanged, and the result
// goes to deliver_processed_audio_callback_. Each delivery is stamped with
// `audio_capture_time - added_aec_delay_`. It carries a new normalized mic
// volume only when ProcessData() reports one.
void AudioProcessor::ProcessCapturedAudio(const media::AudioBus& audio_source,
                                          base::TimeTicks audio_capture_time,
                                          int num_preferred_channels,
                                          double volume) {
  DCHECK(deliver_processed_audio_callback_);
  DCHECK(input_format_.IsValid());
  DCHECK_EQ(audio_source.channels(), input_format_.channels());
  DCHECK_EQ(audio_source.frames(), input_format_.frames_per_buffer());

  // Starts as the time elapsed since capture. Each Consume() then overwrites
  // it with the delay of the block it returns.
  base::TimeDelta block_delay = base::TimeTicks::Now() - audio_capture_time;
  TRACE_EVENT("audio", "AudioProcessor::ProcessCapturedAudio",
              "capture_time (ms)",
              (audio_capture_time - base::TimeTicks()).InMillisecondsF(),
              "capture_delay (ms)", block_delay.InMillisecondsF());
  capture_fifo_->Push(audio_source, block_delay);

  const base::TimeTicks delivery_time = audio_capture_time - added_aec_delay_;
  for (AudioProcessorCaptureBus* block = nullptr;
       capture_fifo_->Consume(&block, &block_delay);) {
    if (!webrtc_audio_processing_) {
      // No APM: deliver the rebuffered block as-is, with no volume change.
      deliver_processed_audio_callback_.Run(*block->bus(), delivery_time,
                                            std::optional<double>());
      continue;
    }
    AudioProcessorCaptureBus* const processed = output_bus_.get();
    const std::optional<double> new_volume =
        ProcessData(block->channel_ptrs(), block->bus()->frames(), block_delay,
                    volume, num_preferred_channels, processed);
    deliver_processed_audio_callback_.Run(*processed->bus(), delivery_time,
                                          new_volume);
  }
}
// Logs the request, then forwards the muted state to the APM if there is one.
void AudioProcessor::SetOutputWillBeMuted(bool muted) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(owning_sequence_);
  SendLogMessage(
      base::StringPrintf("%s({muted=%s})", __func__, base::ToString(muted)));
  if (!webrtc_audio_processing_) {
    return;
  }
  webrtc_audio_processing_->set_output_will_be_muted(muted);
}
// Starts an AEC dump into `dump_file` on a lazily created, low-priority WebRTC
// task queue. Without an APM there is nothing to dump. The file is then moved
// into a no-op thread-pool task (MayBlock), so it is destroyed on a pool
// thread rather than on this sequence.
void AudioProcessor::OnStartDump(base::File dump_file) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(owning_sequence_);
  DCHECK(dump_file.IsValid());
  if (!webrtc_audio_processing_) {
    base::ThreadPool::PostTask(
        FROM_HERE, {base::TaskPriority::LOWEST, base::MayBlock()},
        base::DoNothingWithBoundArgs(std::move(dump_file)));
    return;
  }
  // Reused across dumps until OnStopDump() releases it.
  if (!worker_queue_) {
    worker_queue_ =
        CreateWebRtcTaskQueue(webrtc::TaskQueueFactory::Priority::LOW);
  }
  media::StartEchoCancellationDump(webrtc_audio_processing_.get(),
                                   std::move(dump_file), worker_queue_.get());
}
// Stops an in-progress AEC dump and releases the worker queue. The queue is
// only created by OnStartDump() when an APM is present, so this is a no-op
// if no dump was ever started.
void AudioProcessor::OnStopDump() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(owning_sequence_);
  if (worker_queue_) {
    if (webrtc_audio_processing_) {
      media::StopEchoCancellationDump(webrtc_audio_processing_.get());
    }
    worker_queue_ = nullptr;
  }
}
// Receives playout (render) audio; ignored when there is no APM.
// Records `audio_delay` as the unbuffered playout delay. On the first call or
// on any sample-rate change, it re-targets playout_fifo_ to one APM block at
// the new rate. It then pushes the audio; playout_fifo_ presumably invokes
// AnalyzePlayoutData() for each complete block (TODO confirm).
void AudioProcessor::OnPlayoutData(const AudioBus& audio_bus,
                                   int sample_rate,
                                   base::TimeDelta audio_delay) {
  TRACE_EVENT1("audio", "AudioProcessor::OnPlayoutData", "playout_delay (ms)",
               audio_delay.InMillisecondsF());
  if (!webrtc_audio_processing_) {
    return;
  }
  unbuffered_playout_delay_ = audio_delay;
  // An empty optional compares unequal to any rate, so this also covers the
  // very first call.
  const bool rate_changed = playout_sample_rate_hz_ != sample_rate;
  if (rate_changed) {
    playout_sample_rate_hz_ = sample_rate;
    playout_fifo_.Reset(webrtc::AudioProcessing::GetFrameSize(sample_rate));
  }
  playout_fifo_.Push(audio_bus);
}
void AudioProcessor::AnalyzePlayoutData(const AudioBus& audio_bus,
int frame_delay) {
DCHECK(webrtc_audio_processing_);
DCHECK(playout_sample_rate_hz_.has_value());
const base::TimeDelta playout_delay =
unbuffered_playout_delay_ +
AudioTimestampHelper::FramesToTime(frame_delay, *playout_sample_rate_hz_);
playout_delay_ = playout_delay;
TRACE_EVENT("audio", "AudioProcessor::AnalyzePlayoutData", "delay (frames)",
frame_delay, "playout_delay (ms)",
playout_delay.InMillisecondsF());
webrtc::StreamConfig input_stream_config(*playout_sample_rate_hz_,
audio_bus.channels());
std::array<const float*, media::limits::kMaxChannels> input_ptrs;
for (int i = 0; i < audio_bus.channels(); ++i) {
input_ptrs[i] = audio_bus.channel_span(i).data();
}
const int apm_error = webrtc_audio_processing_->AnalyzeReverseStream(
input_ptrs.data(), input_stream_config);
if (apm_error != webrtc::AudioProcessing::kNoError &&
apm_playout_error_code_log_count_ < 10) {
LOG(ERROR) << "MSAP::OnPlayoutData: AnalyzeReverseStream error="
<< apm_error;
++apm_playout_error_code_log_count_;
}
}
// Returns the APM's statistics, or default-constructed stats when there is
// no APM.
webrtc::AudioProcessingStats AudioProcessor::GetStats() {
  if (webrtc_audio_processing_) {
    return webrtc_audio_processing_->GetStatistics();
  }
  return webrtc::AudioProcessingStats();
}
// Runs one capture block through the WebRTC APM.
//
// `process_ptrs` holds one pointer per input channel, laid out per
// input_format_. The processed audio is written into `output_bus`.
// The APM stream delay is `capture_delay` plus the latest playout delay.
// `volume` is scaled by MaxWebRtcAnalogGainLevel(), rounded, and capped at
// that maximum before being passed as the analog level.
// Returns the APM's recommended level divided by the same maximum, or
// std::nullopt when the recommendation equals the level passed in.
// `process_frames` is not used by this implementation.
std::optional<double> AudioProcessor::ProcessData(
    base::span<const float* const> process_ptrs,
    int process_frames,
    base::TimeDelta capture_delay,
    double volume,
    int num_preferred_channels,
    AudioProcessorCaptureBus* output_bus) {
  DCHECK(webrtc_audio_processing_);
  webrtc::AudioProcessing* const apm = webrtc_audio_processing_.get();

  // Snapshot of the delay last computed on the playout side.
  const base::TimeDelta render_delay = playout_delay_;
  TRACE_EVENT2("audio", "AudioProcessor::ProcessData", "capture_delay (ms)",
               capture_delay.InMillisecondsF(), "playout_delay (ms)",
               render_delay.InMillisecondsF());

  const base::TimeDelta total_delay = capture_delay + render_delay;
  const int64_t total_delay_ms = total_delay.InMilliseconds();
  // Warn, at most 10 times per instance, when the combined delay exceeds 300 ms.
  const bool delay_is_large = total_delay_ms > 300;
  if (delay_is_large && large_delay_log_count_ < 10) {
    LOG(WARNING) << "Large audio delay, capture delay: "
                 << capture_delay.InMillisecondsF()
                 << "ms; playout delay: " << render_delay.InMillisecondsF()
                 << "ms";
    ++large_delay_log_count_;
  }
  audio_delay_stats_reporter_.ReportDelay(capture_delay, render_delay);

  DCHECK_LE(total_delay_ms, std::numeric_limits<int>::max());
  apm->set_stream_delay_ms(base::saturated_cast<int>(total_delay_ms));

  // The preferred output channel count only ever grows.
  max_num_preferred_output_channels_ =
      std::max(num_preferred_channels, max_num_preferred_output_channels_);

#if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_MAC)
  DCHECK_LE(volume, 1.0);
#elif BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_OPENBSD)
  DCHECK_LE(volume, 1.6);
#endif
  const int max_level = media::MaxWebRtcAnalogGainLevel();
  const int analog_level =
      std::min(static_cast<int>(volume * max_level + 0.5), max_level);
  DCHECK_LE(analog_level, max_level);
  apm->set_stream_analog_level(analog_level);

  // APM output channels: no more than the sinks' maximum preference, the
  // output format's channel count, or 2; and at least 1 (CHECKed).
  const int apm_channels = std::min(
      {max_num_preferred_output_channels_, output_format_.channels(), 2});
  CHECK_GE(apm_channels, 1);
  const webrtc::StreamConfig output_config(output_format_.sample_rate(),
                                           apm_channels);
  const int err = apm->ProcessStream(
      process_ptrs.data(), CreateStreamConfig(input_format_), output_config,
      output_bus->channel_ptrs().data());
  DCHECK_EQ(err, 0) << "ProcessStream() error: " << err;

  // Mono APM output with a multi-channel output format: copy the processed
  // channel 0 into channel 1.
  if (apm_channels == 1 && output_format_.channels() > 1) {
    CHECK_GE(output_bus->bus()->channels(), 2);
    output_bus->bus()->channel_span(1).copy_from_nonoverlapping(
        output_bus->bus()->channel_span(0));
  }

  const int recommended_level = apm->recommended_stream_analog_level();
  if (recommended_level != analog_level) {
    return static_cast<double>(recommended_level) /
           media::MaxWebRtcAnalogGainLevel();
  }
  return std::nullopt;
}
// Forwards `message` to log_callback_ with an "MSAP::" prefix, tagged with
// this instance's address in hex.
void AudioProcessor::SendLogMessage(const std::string& message) {
  const uintptr_t self_address = reinterpret_cast<uintptr_t>(this);
  log_callback_.Run(base::StringPrintf("MSAP::%s [this=0x%" PRIXPTR "]",
                                       message.c_str(), self_address));
}
std::optional<AudioParameters> AudioProcessor::ComputeInputFormat(
const AudioParameters& device_format,
const AudioProcessingSettings& audio_processing_settings) {
const ChannelLayout channel_layout = device_format.channel_layout();
if (channel_layout != CHANNEL_LAYOUT_MONO &&
channel_layout != CHANNEL_LAYOUT_STEREO &&
channel_layout != CHANNEL_LAYOUT_DISCRETE) {
return std::nullopt;
}
AudioParameters params(
device_format.format(), device_format.channel_layout_config(),
device_format.sample_rate(),
GetCaptureBufferSize(
audio_processing_settings.NeedWebrtcAudioProcessing(),
device_format));
params.set_effects(device_format.effects());
if (channel_layout == CHANNEL_LAYOUT_DISCRETE) {
DCHECK_LE(device_format.channels(), 2);
}
DVLOG(1) << params.AsHumanReadableString();
CHECK(params.IsValid());
return params;
}
// Returns the default output format for `input_format`.
// With WebRTC processing:
// - sample rate is WebRtcAudioProcessingSampleRateHz(), capped at the input
//   rate on Cast builds;
// - buffer is one APM block (GetFrameSize());
// - channel layout follows the input only when multi_channel_capture_processing
//   is set, and is mono otherwise.
// Without processing, the input rate and layout are kept, and the buffer is
// the smaller of one GetFrameSize() block and the input buffer.
AudioParameters AudioProcessor::GetDefaultOutputFormat(
    const AudioParameters& input_format,
    const AudioProcessingSettings& settings) {
  const bool use_apm = settings.NeedWebrtcAudioProcessing();

  int sample_rate = input_format.sample_rate();
  if (use_apm) {
#if BUILDFLAG(IS_CASTOS) || BUILDFLAG(IS_CAST_ANDROID)
    sample_rate = std::min(media::WebRtcAudioProcessingSampleRateHz(),
                           input_format.sample_rate());
#else
    sample_rate = media::WebRtcAudioProcessingSampleRateHz();
#endif
  }

  const bool mono_output =
      use_apm && !settings.multi_channel_capture_processing;
  const media::ChannelLayoutConfig layout_config =
      mono_output ? ChannelLayoutConfig::Mono()
                  : input_format.channel_layout_config();

  const int apm_block_frames =
      webrtc::AudioProcessing::GetFrameSize(sample_rate);
  const int frames =
      use_apm ? apm_block_frames
              : std::min(apm_block_frames, input_format.frames_per_buffer());

  return media::AudioParameters(input_format.format(), layout_config,
                                sample_rate, frames);
}
}