Bump to WebRTC M120 release

Some API deprecations -- ExperimentalAgc and ExperimentalNs are gone.
We're continuing to carry iSAC even though it's gone upstream, but maybe
we'll want to drop that soon.
Arun Raghavan
2023-12-12 10:42:58 -05:00
parent 9a202fb8c2
commit c6abf6cd3f
479 changed files with 20900 additions and 11996 deletions

modules/audio_processing/agc2/BUILD.gn

@@ -8,48 +8,39 @@
import("../../../webrtc.gni")
group("agc2") {
deps = [
":adaptive_digital",
":fixed_digital",
]
}
rtc_library("level_estimation_agc") {
rtc_library("speech_level_estimator") {
sources = [
"adaptive_mode_level_estimator_agc.cc",
"adaptive_mode_level_estimator_agc.h",
"speech_level_estimator.cc",
"speech_level_estimator.h",
]
visibility = [
"..:gain_controller2",
"./*",
]
configs += [ "..:apm_debug_dump" ]
deps = [
":adaptive_digital",
":common",
":gain_applier",
":noise_level_estimator",
":rnn_vad_with_level",
"..:api",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:logging",
"../../../rtc_base:safe_minmax",
"../agc:level_estimation",
"../vad",
]
}
rtc_library("adaptive_digital") {
rtc_library("adaptive_digital_gain_controller") {
sources = [
"adaptive_agc.cc",
"adaptive_agc.h",
"adaptive_digital_gain_applier.cc",
"adaptive_digital_gain_applier.h",
"adaptive_mode_level_estimator.cc",
"adaptive_mode_level_estimator.h",
"saturation_protector.cc",
"saturation_protector.h",
"adaptive_digital_gain_controller.cc",
"adaptive_digital_gain_controller.h",
]
visibility = [
"..:gain_controller2",
"./*",
]
configs += [ "..:apm_debug_dump" ]
@@ -57,20 +48,39 @@ rtc_library("adaptive_digital") {
deps = [
":common",
":gain_applier",
":noise_level_estimator",
":rnn_vad_with_level",
"..:api",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:logging",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:safe_compare",
"../../../rtc_base:safe_minmax",
"../../../system_wrappers:metrics",
]
}
rtc_library("saturation_protector") {
sources = [
"saturation_protector.cc",
"saturation_protector.h",
"saturation_protector_buffer.cc",
"saturation_protector_buffer.h",
]
visibility = [
"..:gain_controller2",
"./*",
]
configs += [ "..:apm_debug_dump" ]
deps = [
":common",
"..:apm_logging",
"../../../rtc_base:checks",
"../../../rtc_base:safe_compare",
"../../../rtc_base:safe_minmax",
]
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
}
@@ -83,10 +93,36 @@ rtc_library("biquad_filter") {
]
deps = [
"../../../api:array_view",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:macromagic",
]
}
rtc_library("clipping_predictor") {
visibility = [
"../agc:agc",
"./*",
]
sources = [
"clipping_predictor.cc",
"clipping_predictor.h",
"clipping_predictor_level_buffer.cc",
"clipping_predictor_level_buffer.h",
]
deps = [
":gain_map",
"..:api",
"..:audio_frame_view",
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:logging",
"../../../rtc_base:safe_minmax",
]
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
}
rtc_source_set("common") {
sources = [ "agc2_common.h" ]
}
@@ -101,6 +137,12 @@ rtc_library("fixed_digital") {
"limiter.h",
]
visibility = [
"..:gain_controller2",
"../../audio_mixer:audio_mixer_impl",
"./*",
]
configs += [ "..:apm_debug_dump" ]
deps = [
@@ -111,10 +153,12 @@ rtc_library("fixed_digital") {
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:gtest_prod",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:safe_conversions",
"../../../rtc_base:safe_minmax",
"../../../rtc_base:stringutils",
"../../../system_wrappers:metrics",
]
absl_deps = [ "//third_party/abseil-cpp/absl/strings" ]
}
rtc_library("gain_applier") {
@@ -122,6 +166,12 @@ rtc_library("gain_applier") {
"gain_applier.cc",
"gain_applier.h",
]
visibility = [
"..:gain_controller2",
"./*",
]
deps = [
":common",
"..:audio_frame_view",
@@ -130,39 +180,94 @@ rtc_library("gain_applier") {
]
}
rtc_source_set("gain_map") {
visibility = [
"..:analog_mic_simulation",
"../agc:agc",
"./*",
]
sources = [ "gain_map_internal.h" ]
}
rtc_library("input_volume_controller") {
sources = [
"input_volume_controller.cc",
"input_volume_controller.h",
"speech_probability_buffer.cc",
"speech_probability_buffer.h",
]
visibility = [
"..:gain_controller2",
"./*",
]
configs += [ "..:apm_debug_dump" ]
deps = [
":clipping_predictor",
":gain_map",
":input_volume_stats_reporter",
"..:api",
"..:audio_buffer",
"..:audio_frame_view",
"../../../api:array_view",
"../../../rtc_base:checks",
"../../../rtc_base:checks",
"../../../rtc_base:gtest_prod",
"../../../rtc_base:gtest_prod",
"../../../rtc_base:logging",
"../../../rtc_base:safe_minmax",
"../../../system_wrappers:field_trial",
"../../../system_wrappers:metrics",
]
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
}
rtc_library("noise_level_estimator") {
sources = [
"down_sampler.cc",
"down_sampler.h",
"noise_level_estimator.cc",
"noise_level_estimator.h",
"noise_spectrum_estimator.cc",
"noise_spectrum_estimator.h",
"signal_classifier.cc",
"signal_classifier.h",
]
deps = [
":biquad_filter",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../common_audio/third_party/ooura:fft_size_128",
"../../../rtc_base:checks",
"../../../rtc_base:macromagic",
"../../../system_wrappers",
]
visibility = [
"..:gain_controller2",
"./*",
]
configs += [ "..:apm_debug_dump" ]
}
rtc_library("rnn_vad_with_level") {
rtc_library("vad_wrapper") {
sources = [
"vad_with_level.cc",
"vad_with_level.h",
"vad_wrapper.cc",
"vad_wrapper.h",
]
visibility = [
"..:gain_controller2",
"./*",
]
defines = []
if (rtc_build_with_neon && current_cpu != "arm64") {
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
cflags = [ "-mfpu=neon" ]
}
deps = [
":common",
":cpu_features",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
@@ -172,28 +277,85 @@ rtc_library("rnn_vad_with_level") {
]
}
rtc_library("adaptive_digital_unittests") {
rtc_library("cpu_features") {
sources = [
"cpu_features.cc",
"cpu_features.h",
]
visibility = [
"..:gain_controller2",
"./*",
]
deps = [
"../../../rtc_base:stringutils",
"../../../rtc_base/system:arch",
"../../../system_wrappers",
]
}
rtc_library("speech_level_estimator_unittest") {
testonly = true
configs += [ "..:apm_debug_dump" ]
sources = [ "speech_level_estimator_unittest.cc" ]
deps = [
":common",
":speech_level_estimator",
"..:api",
"..:apm_logging",
"../../../rtc_base:gunit_helpers",
"../../../test:test_support",
]
}
rtc_library("adaptive_digital_gain_controller_unittest") {
testonly = true
configs += [ "..:apm_debug_dump" ]
sources = [ "adaptive_digital_gain_controller_unittest.cc" ]
deps = [
":adaptive_digital_gain_controller",
":common",
":test_utils",
"..:api",
"..:apm_logging",
"..:audio_frame_view",
"../../../common_audio",
"../../../rtc_base:gunit_helpers",
"../../../test:test_support",
]
}
rtc_library("gain_applier_unittest") {
testonly = true
configs += [ "..:apm_debug_dump" ]
sources = [ "gain_applier_unittest.cc" ]
deps = [
":gain_applier",
":test_utils",
"..:audio_frame_view",
"../../../rtc_base:gunit_helpers",
"../../../test:test_support",
]
}
rtc_library("saturation_protector_unittest") {
testonly = true
configs += [ "..:apm_debug_dump" ]
sources = [
"adaptive_digital_gain_applier_unittest.cc",
"adaptive_mode_level_estimator_unittest.cc",
"gain_applier_unittest.cc",
"saturation_protector_buffer_unittest.cc",
"saturation_protector_unittest.cc",
]
deps = [
":adaptive_digital",
":common",
":gain_applier",
":test_utils",
":saturation_protector",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:rtc_base_approved",
"../../../test:test_support",
]
}
@@ -232,38 +394,67 @@ rtc_library("fixed_digital_unittests") {
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:rtc_base_approved",
"../../../system_wrappers:metrics",
]
}
rtc_library("input_volume_controller_unittests") {
testonly = true
sources = [
"clipping_predictor_level_buffer_unittest.cc",
"clipping_predictor_unittest.cc",
"input_volume_controller_unittest.cc",
"speech_probability_buffer_unittest.cc",
]
configs += [ "..:apm_debug_dump" ]
deps = [
":clipping_predictor",
":gain_map",
":input_volume_controller",
"..:api",
"../../../api:array_view",
"../../../rtc_base:checks",
"../../../rtc_base:random",
"../../../rtc_base:safe_conversions",
"../../../rtc_base:safe_minmax",
"../../../rtc_base:stringutils",
"../../../system_wrappers:metrics",
"../../../test:field_trial",
"../../../test:fileutils",
"../../../test:test_support",
"//testing/gtest",
]
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
}
rtc_library("noise_estimator_unittests") {
testonly = true
configs += [ "..:apm_debug_dump" ]
sources = [
"noise_level_estimator_unittest.cc",
"signal_classifier_unittest.cc",
]
sources = [ "noise_level_estimator_unittest.cc" ]
deps = [
":noise_level_estimator",
":test_utils",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../api:function_view",
"../../../rtc_base:checks",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:rtc_base_approved",
]
}
rtc_library("rnn_vad_with_level_unittests") {
rtc_library("vad_wrapper_unittests") {
testonly = true
sources = [ "vad_with_level_unittest.cc" ]
sources = [ "vad_wrapper_unittest.cc" ]
deps = [
":common",
":rnn_vad_with_level",
":vad_wrapper",
"..:audio_frame_view",
"../../../rtc_base:checks",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:safe_compare",
"../../../test:test_support",
@@ -285,6 +476,36 @@ rtc_library("test_utils") {
deps = [
"..:audio_frame_view",
"../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:random",
]
}
rtc_library("input_volume_stats_reporter") {
sources = [
"input_volume_stats_reporter.cc",
"input_volume_stats_reporter.h",
]
deps = [
"../../../rtc_base:gtest_prod",
"../../../rtc_base:logging",
"../../../rtc_base:safe_minmax",
"../../../rtc_base:stringutils",
"../../../system_wrappers:metrics",
]
absl_deps = [
"//third_party/abseil-cpp/absl/strings",
"//third_party/abseil-cpp/absl/types:optional",
]
}
rtc_library("input_volume_stats_reporter_unittests") {
testonly = true
sources = [ "input_volume_stats_reporter_unittest.cc" ]
deps = [
":input_volume_stats_reporter",
"../../../rtc_base:stringutils",
"../../../system_wrappers:metrics",
"../../../test:test_support",
]
absl_deps = [ "//third_party/abseil-cpp/absl/strings" ]
}

modules/audio_processing/agc2/adaptive_agc.cc

@@ -1,90 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/adaptive_agc.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
namespace webrtc {
namespace {
void DumpDebugData(const AdaptiveDigitalGainApplier::FrameInfo& info,
ApmDataDumper& dumper) {
dumper.DumpRaw("agc2_vad_probability", info.vad_result.speech_probability);
dumper.DumpRaw("agc2_vad_rms_dbfs", info.vad_result.rms_dbfs);
dumper.DumpRaw("agc2_vad_peak_dbfs", info.vad_result.peak_dbfs);
dumper.DumpRaw("agc2_noise_estimate_dbfs", info.input_noise_level_dbfs);
dumper.DumpRaw("agc2_last_limiter_audio_level", info.limiter_envelope_dbfs);
}
constexpr int kGainApplierAdjacentSpeechFramesThreshold = 1;
constexpr float kMaxGainChangePerSecondDb = 3.f;
constexpr float kMaxOutputNoiseLevelDbfs = -50.f;
} // namespace
AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper)
: speech_level_estimator_(apm_data_dumper),
gain_applier_(apm_data_dumper,
kGainApplierAdjacentSpeechFramesThreshold,
kMaxGainChangePerSecondDb,
kMaxOutputNoiseLevelDbfs),
apm_data_dumper_(apm_data_dumper),
noise_level_estimator_(apm_data_dumper) {
RTC_DCHECK(apm_data_dumper);
}
AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2& config)
: speech_level_estimator_(
apm_data_dumper,
config.adaptive_digital.level_estimator,
config.adaptive_digital
.level_estimator_adjacent_speech_frames_threshold,
config.adaptive_digital.initial_saturation_margin_db,
config.adaptive_digital.extra_saturation_margin_db),
vad_(config.adaptive_digital.vad_probability_attack),
gain_applier_(
apm_data_dumper,
config.adaptive_digital.gain_applier_adjacent_speech_frames_threshold,
config.adaptive_digital.max_gain_change_db_per_second,
config.adaptive_digital.max_output_noise_level_dbfs),
apm_data_dumper_(apm_data_dumper),
noise_level_estimator_(apm_data_dumper) {
RTC_DCHECK(apm_data_dumper);
if (!config.adaptive_digital.use_saturation_protector) {
RTC_LOG(LS_WARNING) << "The saturation protector cannot be disabled.";
}
}
AdaptiveAgc::~AdaptiveAgc() = default;
void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
AdaptiveDigitalGainApplier::FrameInfo info;
info.vad_result = vad_.AnalyzeFrame(frame);
speech_level_estimator_.Update(info.vad_result);
info.input_level_dbfs = speech_level_estimator_.level_dbfs();
info.input_noise_level_dbfs = noise_level_estimator_.Analyze(frame);
info.limiter_envelope_dbfs =
limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.f;
info.estimate_is_confident = speech_level_estimator_.IsConfident();
DumpDebugData(info, *apm_data_dumper_);
gain_applier_.Process(info, frame);
}
void AdaptiveAgc::Reset() {
speech_level_estimator_.Reset();
}
} // namespace webrtc

modules/audio_processing/agc2/adaptive_agc.h

@@ -1,50 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
#include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/noise_level_estimator.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/include/audio_processing.h"
namespace webrtc {
class ApmDataDumper;
// Adaptive digital gain controller.
// TODO(crbug.com/webrtc/7494): Unify with `AdaptiveDigitalGainApplier`.
class AdaptiveAgc {
public:
explicit AdaptiveAgc(ApmDataDumper* apm_data_dumper);
// TODO(crbug.com/webrtc/7494): Remove ctor above.
AdaptiveAgc(ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2& config);
~AdaptiveAgc();
// Analyzes `frame` and applies a digital adaptive gain to it. Takes into
// account the envelope measured by the limiter.
// TODO(crbug.com/webrtc/7494): Make the class depend on the limiter.
void Process(AudioFrameView<float> frame, float limiter_envelope);
void Reset();
private:
AdaptiveModeLevelEstimator speech_level_estimator_;
VadLevelAnalyzer vad_;
AdaptiveDigitalGainApplier gain_applier_;
ApmDataDumper* const apm_data_dumper_;
NoiseLevelEstimator noise_level_estimator_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_

modules/audio_processing/agc2/adaptive_digital_gain_applier.cc

@@ -1,179 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
#include <algorithm>
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
namespace {
// This function maps input level to desired applied gain. We want to
// boost the signal so that peaks are at -kHeadroomDbfs. We can't
// apply more than kMaxGainDb gain.
float ComputeGainDb(float input_level_dbfs) {
// If the level is very low, boost it as much as we can.
if (input_level_dbfs < -(kHeadroomDbfs + kMaxGainDb)) {
return kMaxGainDb;
}
// We expect to end up here most of the time: the level is below
// -headroom, but we can boost it to -headroom.
if (input_level_dbfs < -kHeadroomDbfs) {
return -kHeadroomDbfs - input_level_dbfs;
}
// Otherwise, the level is too high and we can't boost. The
// LevelEstimator is responsible for not reporting bogus gain
// values.
RTC_DCHECK_LE(input_level_dbfs, 0.f);
return 0.f;
}
// Returns `target_gain` if the output noise level is below
// `max_output_noise_level_dbfs`; otherwise returns a capped gain so that the
// output noise level equals `max_output_noise_level_dbfs`.
float LimitGainByNoise(float target_gain,
float input_noise_level_dbfs,
float max_output_noise_level_dbfs,
ApmDataDumper& apm_data_dumper) {
const float noise_headroom_db =
max_output_noise_level_dbfs - input_noise_level_dbfs;
apm_data_dumper.DumpRaw("agc2_noise_headroom_db", noise_headroom_db);
return std::min(target_gain, std::max(noise_headroom_db, 0.f));
}
float LimitGainByLowConfidence(float target_gain,
float last_gain,
float limiter_audio_level_dbfs,
bool estimate_is_confident) {
if (estimate_is_confident ||
limiter_audio_level_dbfs <= kLimiterThresholdForAgcGainDbfs) {
return target_gain;
}
const float limiter_level_before_gain = limiter_audio_level_dbfs - last_gain;
// Compute a new gain so that limiter_level_before_gain + new_gain <=
// kLimiterThreshold.
const float new_target_gain = std::max(
kLimiterThresholdForAgcGainDbfs - limiter_level_before_gain, 0.f);
return std::min(new_target_gain, target_gain);
}
// Computes how the gain should change during this frame.
// Returns the gain difference in dB relative to `last_gain_db`.
float ComputeGainChangeThisFrameDb(float target_gain_db,
float last_gain_db,
bool gain_increase_allowed,
float max_gain_change_db) {
float target_gain_difference_db = target_gain_db - last_gain_db;
if (!gain_increase_allowed) {
target_gain_difference_db = std::min(target_gain_difference_db, 0.f);
}
return rtc::SafeClamp(target_gain_difference_db, -max_gain_change_db,
max_gain_change_db);
}
} // namespace
AdaptiveDigitalGainApplier::AdaptiveDigitalGainApplier(
ApmDataDumper* apm_data_dumper,
int adjacent_speech_frames_threshold,
float max_gain_change_db_per_second,
float max_output_noise_level_dbfs)
: apm_data_dumper_(apm_data_dumper),
gain_applier_(
/*hard_clip_samples=*/false,
/*initial_gain_factor=*/DbToRatio(kInitialAdaptiveDigitalGainDb)),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
max_gain_change_db_per_10ms_(max_gain_change_db_per_second *
kFrameDurationMs / 1000.f),
max_output_noise_level_dbfs_(max_output_noise_level_dbfs),
calls_since_last_gain_log_(0),
frames_to_gain_increase_allowed_(adjacent_speech_frames_threshold_),
last_gain_db_(kInitialAdaptiveDigitalGainDb) {
RTC_DCHECK_GT(max_gain_change_db_per_second, 0.f);
RTC_DCHECK_GE(frames_to_gain_increase_allowed_, 1);
RTC_DCHECK_GE(max_output_noise_level_dbfs_, -90.f);
RTC_DCHECK_LE(max_output_noise_level_dbfs_, 0.f);
}
void AdaptiveDigitalGainApplier::Process(const FrameInfo& info,
AudioFrameView<float> frame) {
RTC_DCHECK_GE(info.input_level_dbfs, -150.f);
RTC_DCHECK_GE(frame.num_channels(), 1);
RTC_DCHECK(
frame.samples_per_channel() == 80 || frame.samples_per_channel() == 160 ||
frame.samples_per_channel() == 320 || frame.samples_per_channel() == 480)
<< "`frame` does not look like a 10 ms frame for an APM supported sample "
"rate";
const float target_gain_db = LimitGainByLowConfidence(
LimitGainByNoise(ComputeGainDb(std::min(info.input_level_dbfs, 0.f)),
info.input_noise_level_dbfs,
max_output_noise_level_dbfs_, *apm_data_dumper_),
last_gain_db_, info.limiter_envelope_dbfs, info.estimate_is_confident);
// Forbid increasing the gain until enough adjacent speech frames are
// observed.
if (info.vad_result.speech_probability < kVadConfidenceThreshold) {
frames_to_gain_increase_allowed_ = adjacent_speech_frames_threshold_;
} else if (frames_to_gain_increase_allowed_ > 0) {
frames_to_gain_increase_allowed_--;
}
const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb(
target_gain_db, last_gain_db_,
/*gain_increase_allowed=*/frames_to_gain_increase_allowed_ == 0,
max_gain_change_db_per_10ms_);
apm_data_dumper_->DumpRaw("agc2_want_to_change_by_db",
target_gain_db - last_gain_db_);
apm_data_dumper_->DumpRaw("agc2_will_change_by_db",
gain_change_this_frame_db);
// Optimization: avoid calling math functions if gain does not
// change.
if (gain_change_this_frame_db != 0.f) {
gain_applier_.SetGainFactor(
DbToRatio(last_gain_db_ + gain_change_this_frame_db));
}
gain_applier_.ApplyGain(frame);
// Remember that the gain has changed for the next iteration.
last_gain_db_ = last_gain_db_ + gain_change_this_frame_db;
apm_data_dumper_->DumpRaw("agc2_applied_gain_db", last_gain_db_);
// Log every 10 seconds.
calls_since_last_gain_log_++;
if (calls_since_last_gain_log_ == 1000) {
calls_since_last_gain_log_ = 0;
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.DigitalGainApplied",
last_gain_db_, 0, kMaxGainDb, kMaxGainDb + 1);
RTC_HISTOGRAM_COUNTS_LINEAR(
"WebRTC.Audio.Agc2.EstimatedSpeechPlusNoiseLevel",
-info.input_level_dbfs, 0, 100, 101);
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedNoiseLevel",
-info.input_noise_level_dbfs, 0, 100, 101);
RTC_LOG(LS_INFO) << "AGC2 adaptive digital"
<< " | speech_plus_noise_dbfs: " << info.input_level_dbfs
<< " | noise_dbfs: " << info.input_noise_level_dbfs
<< " | gain_db: " << last_gain_db_;
}
}
} // namespace webrtc

modules/audio_processing/agc2/adaptive_digital_gain_applier.h

@@ -1,69 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_
#include "modules/audio_processing/agc2/gain_applier.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
class ApmDataDumper;
// Part of the adaptive digital controller that applies a digital adaptive gain.
// The gain is updated towards a target. The logic decides when gain updates are
// allowed, it controls the adaptation speed and caps the target based on the
// estimated noise level and the speech level estimate confidence.
class AdaptiveDigitalGainApplier {
public:
// Information about a frame to process.
struct FrameInfo {
float input_level_dbfs; // Estimated speech plus noise level.
float input_noise_level_dbfs; // Estimated noise level.
VadLevelAnalyzer::Result vad_result;
float limiter_envelope_dbfs; // Envelope level from the limiter.
bool estimate_is_confident;
};
// Ctor.
// `adjacent_speech_frames_threshold` indicates how many speech frames are
// required before a gain increase is allowed. `max_gain_change_db_per_second`
// limits the adaptation speed (uniformly operated across frames).
// `max_output_noise_level_dbfs` limits the output noise level.
AdaptiveDigitalGainApplier(ApmDataDumper* apm_data_dumper,
int adjacent_speech_frames_threshold,
float max_gain_change_db_per_second,
float max_output_noise_level_dbfs);
AdaptiveDigitalGainApplier(const AdaptiveDigitalGainApplier&) = delete;
AdaptiveDigitalGainApplier& operator=(const AdaptiveDigitalGainApplier&) =
delete;
// Analyzes `info`, updates the digital gain and applies it to a 10 ms
// `frame`. Supports any sample rate supported by APM.
void Process(const FrameInfo& info, AudioFrameView<float> frame);
private:
ApmDataDumper* const apm_data_dumper_;
GainApplier gain_applier_;
const int adjacent_speech_frames_threshold_;
const float max_gain_change_db_per_10ms_;
const float max_output_noise_level_dbfs_;
int calls_since_last_gain_log_;
int frames_to_gain_increase_allowed_;
float last_gain_db_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_

modules/audio_processing/agc2/adaptive_digital_gain_controller.cc

@@ -0,0 +1,216 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/adaptive_digital_gain_controller.h"
#include <algorithm>
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
namespace {
using AdaptiveDigitalConfig =
AudioProcessing::Config::GainController2::AdaptiveDigital;
constexpr int kHeadroomHistogramMin = 0;
constexpr int kHeadroomHistogramMax = 50;
constexpr int kGainDbHistogramMax = 30;
// Computes the gain for `input_level_dbfs` to reach `-config.headroom_db`.
// Clamps the gain in [0, `config.max_gain_db`]. `config.headroom_db` is a
// safety margin to allow transient peaks to exceed the target peak level
// without clipping.
float ComputeGainDb(float input_level_dbfs,
const AdaptiveDigitalConfig& config) {
// If the level is very low, apply the maximum gain.
if (input_level_dbfs < -(config.headroom_db + config.max_gain_db)) {
return config.max_gain_db;
}
// We expect to end up here most of the time: the level is below
// -headroom, but we can boost it to -headroom.
if (input_level_dbfs < -config.headroom_db) {
return -config.headroom_db - input_level_dbfs;
}
// The level is too high and we can't boost.
RTC_DCHECK_GE(input_level_dbfs, -config.headroom_db);
return 0.0f;
}
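// A worked example of the mapping above (illustrative numbers, not values
// taken from this commit: assume config.headroom_db = 5 and
// config.max_gain_db = 30):
//   input_level_dbfs = -60: below -(5 + 30), so the gain is capped at 30 dB.
//   input_level_dbfs = -20: below -5, so the gain is -5 - (-20) = 15 dB,
//                           which raises the level exactly to the -5 dBFS
//                           headroom target.
//   input_level_dbfs =  -3: already above -5, so no gain (0 dB) is applied.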
// Returns `target_gain_db` if applying such a gain to `input_noise_level_dbfs`
// does not exceed `max_output_noise_level_dbfs`. Otherwise lowers and returns
// `target_gain_db` so that the output noise level equals
// `max_output_noise_level_dbfs`.
float LimitGainByNoise(float target_gain_db,
float input_noise_level_dbfs,
float max_output_noise_level_dbfs,
ApmDataDumper& apm_data_dumper) {
const float max_allowed_gain_db =
max_output_noise_level_dbfs - input_noise_level_dbfs;
apm_data_dumper.DumpRaw("agc2_adaptive_gain_applier_max_allowed_gain_db",
max_allowed_gain_db);
return std::min(target_gain_db, std::max(max_allowed_gain_db, 0.0f));
}
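// For example (assumed numbers, for illustration only): with
// target_gain_db = 25, input noise at -60 dBFS and
// max_output_noise_level_dbfs = -50, max_allowed_gain_db = -50 - (-60) = 10,
// so 10 dB is returned. If the noise already exceeds the cap (say -40 dBFS),
// the allowed gain is negative and the function returns 0 dB instead of a
// negative gain.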
float LimitGainByLowConfidence(float target_gain_db,
float last_gain_db,
float limiter_audio_level_dbfs,
bool estimate_is_confident) {
if (estimate_is_confident ||
limiter_audio_level_dbfs <= kLimiterThresholdForAgcGainDbfs) {
return target_gain_db;
}
const float limiter_level_dbfs_before_gain =
limiter_audio_level_dbfs - last_gain_db;
// Compute a new gain so that `limiter_level_dbfs_before_gain` +
// `new_target_gain_db` is not greater than `kLimiterThresholdForAgcGainDbfs`.
const float new_target_gain_db = std::max(
kLimiterThresholdForAgcGainDbfs - limiter_level_dbfs_before_gain, 0.0f);
return std::min(new_target_gain_db, target_gain_db);
}
// Computes how the gain should change during this frame.
// Returns the gain difference in dB relative to `last_gain_db`.
float ComputeGainChangeThisFrameDb(float target_gain_db,
float last_gain_db,
bool gain_increase_allowed,
float max_gain_decrease_db,
float max_gain_increase_db) {
RTC_DCHECK_GT(max_gain_decrease_db, 0);
RTC_DCHECK_GT(max_gain_increase_db, 0);
float target_gain_difference_db = target_gain_db - last_gain_db;
if (!gain_increase_allowed) {
target_gain_difference_db = std::min(target_gain_difference_db, 0.0f);
}
return rtc::SafeClamp(target_gain_difference_db, -max_gain_decrease_db,
max_gain_increase_db);
}
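// A numeric sketch of the clamping above (assumed values, not from this
// commit): with target_gain_db = 12, last_gain_db = 10 and both limits at
// 0.06 dB per 10 ms frame (i.e., 6 dB/s), the desired change of +2 dB is
// clamped to +0.06 dB for this frame; with gain_increase_allowed == false,
// the change is first floored to 0 dB and no increase happens.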
} // namespace
AdaptiveDigitalGainController::AdaptiveDigitalGainController(
ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
int adjacent_speech_frames_threshold)
: apm_data_dumper_(apm_data_dumper),
gain_applier_(
/*hard_clip_samples=*/false,
/*initial_gain_factor=*/DbToRatio(config.initial_gain_db)),
config_(config),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
max_gain_change_db_per_10ms_(config_.max_gain_change_db_per_second *
kFrameDurationMs / 1000.0f),
calls_since_last_gain_log_(0),
frames_to_gain_increase_allowed_(adjacent_speech_frames_threshold),
last_gain_db_(config_.initial_gain_db) {
RTC_DCHECK_GT(max_gain_change_db_per_10ms_, 0.0f);
RTC_DCHECK_GE(frames_to_gain_increase_allowed_, 1);
RTC_DCHECK_GE(config_.max_output_noise_level_dbfs, -90.0f);
RTC_DCHECK_LE(config_.max_output_noise_level_dbfs, 0.0f);
}
void AdaptiveDigitalGainController::Process(const FrameInfo& info,
AudioFrameView<float> frame) {
RTC_DCHECK_GE(info.speech_level_dbfs, -150.0f);
RTC_DCHECK_GE(frame.num_channels(), 1);
RTC_DCHECK(
frame.samples_per_channel() == 80 || frame.samples_per_channel() == 160 ||
frame.samples_per_channel() == 320 || frame.samples_per_channel() == 480)
<< "`frame` does not look like a 10 ms frame for an APM supported sample "
"rate";
// Compute the input level used to select the desired gain.
RTC_DCHECK_GT(info.headroom_db, 0.0f);
const float input_level_dbfs = info.speech_level_dbfs + info.headroom_db;
const float target_gain_db = LimitGainByLowConfidence(
LimitGainByNoise(ComputeGainDb(input_level_dbfs, config_),
info.noise_rms_dbfs, config_.max_output_noise_level_dbfs,
*apm_data_dumper_),
last_gain_db_, info.limiter_envelope_dbfs, info.speech_level_reliable);
// Forbid increasing the gain until enough adjacent speech frames are
// observed.
bool first_confident_speech_frame = false;
if (info.speech_probability < kVadConfidenceThreshold) {
frames_to_gain_increase_allowed_ = adjacent_speech_frames_threshold_;
} else if (frames_to_gain_increase_allowed_ > 0) {
frames_to_gain_increase_allowed_--;
first_confident_speech_frame = frames_to_gain_increase_allowed_ == 0;
}
apm_data_dumper_->DumpRaw(
"agc2_adaptive_gain_applier_frames_to_gain_increase_allowed",
frames_to_gain_increase_allowed_);
const bool gain_increase_allowed = frames_to_gain_increase_allowed_ == 0;
float max_gain_increase_db = max_gain_change_db_per_10ms_;
if (first_confident_speech_frame) {
// No gain increase happened while waiting for a long enough speech
// sequence. Therefore, temporarily allow a faster gain increase.
RTC_DCHECK(gain_increase_allowed);
max_gain_increase_db *= adjacent_speech_frames_threshold_;
}
const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb(
target_gain_db, last_gain_db_, gain_increase_allowed,
/*max_gain_decrease_db=*/max_gain_change_db_per_10ms_,
max_gain_increase_db);
apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_want_to_change_by_db",
target_gain_db - last_gain_db_);
apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_will_change_by_db",
gain_change_this_frame_db);
// Optimization: avoid calling math functions if gain does not
// change.
if (gain_change_this_frame_db != 0.f) {
gain_applier_.SetGainFactor(
DbToRatio(last_gain_db_ + gain_change_this_frame_db));
}
gain_applier_.ApplyGain(frame);
// Remember that the gain has changed for the next iteration.
last_gain_db_ = last_gain_db_ + gain_change_this_frame_db;
apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_applied_gain_db",
last_gain_db_);
// Log every 10 seconds.
calls_since_last_gain_log_++;
if (calls_since_last_gain_log_ == 1000) {
calls_since_last_gain_log_ = 0;
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedSpeechLevel",
-info.speech_level_dbfs, 0, 100, 101);
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedNoiseLevel",
-info.noise_rms_dbfs, 0, 100, 101);
RTC_HISTOGRAM_COUNTS_LINEAR(
"WebRTC.Audio.Agc2.Headroom", info.headroom_db, kHeadroomHistogramMin,
kHeadroomHistogramMax,
kHeadroomHistogramMax - kHeadroomHistogramMin + 1);
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.DigitalGainApplied",
last_gain_db_, 0, kGainDbHistogramMax,
kGainDbHistogramMax + 1);
RTC_LOG(LS_INFO) << "AGC2 adaptive digital"
<< " | speech_dbfs: " << info.speech_level_dbfs
<< " | noise_dbfs: " << info.noise_rms_dbfs
<< " | headroom_db: " << info.headroom_db
<< " | gain_db: " << last_gain_db_;
}
}
} // namespace webrtc

modules/audio_processing/agc2/adaptive_digital_gain_controller.h

@@ -0,0 +1,66 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_CONTROLLER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_CONTROLLER_H_
#include <vector>
#include "modules/audio_processing/agc2/gain_applier.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/include/audio_processing.h"
namespace webrtc {
class ApmDataDumper;
// Selects the target digital gain, decides when and how quickly to adapt to the
// target and applies the current gain to 10 ms frames.
class AdaptiveDigitalGainController {
public:
// Information about a frame to process.
struct FrameInfo {
float speech_probability; // Probability of speech in the [0, 1] range.
float speech_level_dbfs; // Estimated speech level (dBFS).
bool speech_level_reliable; // True with reliable speech level estimation.
float noise_rms_dbfs; // Estimated noise RMS level (dBFS).
float headroom_db; // Headroom (dB).
// TODO(bugs.webrtc.org/7494): Remove `limiter_envelope_dbfs`.
float limiter_envelope_dbfs; // Envelope level from the limiter (dBFS).
};
AdaptiveDigitalGainController(
ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
int adjacent_speech_frames_threshold);
AdaptiveDigitalGainController(const AdaptiveDigitalGainController&) = delete;
AdaptiveDigitalGainController& operator=(
const AdaptiveDigitalGainController&) = delete;
// Analyzes `info`, updates the digital gain and applies it to a 10 ms
// `frame`. Supports any sample rate supported by APM.
void Process(const FrameInfo& info, AudioFrameView<float> frame);
private:
ApmDataDumper* const apm_data_dumper_;
GainApplier gain_applier_;
const AudioProcessing::Config::GainController2::AdaptiveDigital config_;
const int adjacent_speech_frames_threshold_;
const float max_gain_change_db_per_10ms_;
int calls_since_last_gain_log_;
int frames_to_gain_increase_allowed_;
float last_gain_db_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_CONTROLLER_H_
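For orientation, a usage sketch of the new controller API follows (a hedged illustration, not code from this commit: the estimator outputs are invented placeholder values, the threshold 12 mirrors kAdjacentSpeechFramesThreshold from agc2_common.h, and in the real pipeline the controller and data dumper are long-lived members of GainController2 rather than locals created per frame).

#include "modules/audio_processing/agc2/adaptive_digital_gain_controller.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

namespace webrtc {

// Runs one 10 ms update with made-up estimator outputs.
void SketchOneFrame(AudioFrameView<float> frame) {
  ApmDataDumper data_dumper(/*instance_index=*/0);
  AudioProcessing::Config::GainController2::AdaptiveDigital config;  // Defaults.
  AdaptiveDigitalGainController controller(
      &data_dumper, config, /*adjacent_speech_frames_threshold=*/12);
  AdaptiveDigitalGainController::FrameInfo info;
  info.speech_probability = 0.97f;      // From the VAD wrapper.
  info.speech_level_dbfs = -28.0f;      // From the speech level estimator.
  info.speech_level_reliable = true;
  info.noise_rms_dbfs = -65.0f;         // From the noise level estimator.
  info.headroom_db = 5.0f;              // Must be > 0 (DCHECKed in Process()).
  info.limiter_envelope_dbfs = -30.0f;  // From the limiter.
  controller.Process(info, frame);      // Updates and applies the gain in place.
}

}  // namespace webrtc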

modules/audio_processing/agc2/adaptive_mode_level_estimator.cc

@@ -1,198 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
using LevelEstimatorType =
AudioProcessing::Config::GainController2::LevelEstimator;
// Combines a level estimation with the saturation protector margins.
float ComputeLevelEstimateDbfs(float level_estimate_dbfs,
float saturation_margin_db,
float extra_saturation_margin_db) {
return rtc::SafeClamp<float>(
level_estimate_dbfs + saturation_margin_db + extra_saturation_margin_db,
-90.f, 30.f);
}
// Returns the level of given type from `vad_level`.
float GetLevel(const VadLevelAnalyzer::Result& vad_level,
LevelEstimatorType type) {
switch (type) {
case LevelEstimatorType::kRms:
return vad_level.rms_dbfs;
break;
case LevelEstimatorType::kPeak:
return vad_level.peak_dbfs;
break;
}
}
} // namespace
bool AdaptiveModeLevelEstimator::LevelEstimatorState::operator==(
const AdaptiveModeLevelEstimator::LevelEstimatorState& b) const {
return time_to_full_buffer_ms == b.time_to_full_buffer_ms &&
level_dbfs.numerator == b.level_dbfs.numerator &&
level_dbfs.denominator == b.level_dbfs.denominator &&
saturation_protector == b.saturation_protector;
}
float AdaptiveModeLevelEstimator::LevelEstimatorState::Ratio::GetRatio() const {
RTC_DCHECK_NE(denominator, 0.f);
return numerator / denominator;
}
AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
ApmDataDumper* apm_data_dumper)
: AdaptiveModeLevelEstimator(
apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator::kRms,
kDefaultLevelEstimatorAdjacentSpeechFramesThreshold,
kDefaultInitialSaturationMarginDb,
kDefaultExtraSaturationMarginDb) {}
AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
ApmDataDumper* apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator level_estimator,
int adjacent_speech_frames_threshold,
float initial_saturation_margin_db,
float extra_saturation_margin_db)
: apm_data_dumper_(apm_data_dumper),
level_estimator_type_(level_estimator),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
initial_saturation_margin_db_(initial_saturation_margin_db),
extra_saturation_margin_db_(extra_saturation_margin_db),
level_dbfs_(ComputeLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs,
initial_saturation_margin_db_,
extra_saturation_margin_db_)) {
RTC_DCHECK(apm_data_dumper_);
RTC_DCHECK_GE(adjacent_speech_frames_threshold_, 1);
Reset();
}
void AdaptiveModeLevelEstimator::Update(
const VadLevelAnalyzer::Result& vad_level) {
RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
RTC_DCHECK_GT(vad_level.peak_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.peak_dbfs, 50.f);
RTC_DCHECK_GE(vad_level.speech_probability, 0.f);
RTC_DCHECK_LE(vad_level.speech_probability, 1.f);
DumpDebugData();
if (vad_level.speech_probability < kVadConfidenceThreshold) {
// Not a speech frame.
if (adjacent_speech_frames_threshold_ > 1) {
// When two or more adjacent speech frames are required in order to update
// the state, we need to decide whether to discard or confirm the updates
// based on the speech sequence length.
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// First non-speech frame after a long enough sequence of speech frames.
// Update the reliable state.
reliable_state_ = preliminary_state_;
} else if (num_adjacent_speech_frames_ > 0) {
// First non-speech frame after a too short sequence of speech frames.
// Reset to the last reliable state.
preliminary_state_ = reliable_state_;
}
}
num_adjacent_speech_frames_ = 0;
return;
}
// Speech frame observed.
num_adjacent_speech_frames_++;
// Update preliminary level estimate.
RTC_DCHECK_GE(preliminary_state_.time_to_full_buffer_ms, 0);
const bool buffer_is_full = preliminary_state_.time_to_full_buffer_ms == 0;
if (!buffer_is_full) {
preliminary_state_.time_to_full_buffer_ms -= kFrameDurationMs;
}
// Weighted average of levels with speech probability as weight.
RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
const float leak_factor = buffer_is_full ? kFullBufferLeakFactor : 1.f;
preliminary_state_.level_dbfs.numerator =
preliminary_state_.level_dbfs.numerator * leak_factor +
GetLevel(vad_level, level_estimator_type_) * vad_level.speech_probability;
preliminary_state_.level_dbfs.denominator =
preliminary_state_.level_dbfs.denominator * leak_factor +
vad_level.speech_probability;
const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();
UpdateSaturationProtectorState(vad_level.peak_dbfs, level_dbfs,
preliminary_state_.saturation_protector);
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// `preliminary_state_` is now reliable. Update the last level estimation.
level_dbfs_ = ComputeLevelEstimateDbfs(
level_dbfs, preliminary_state_.saturation_protector.margin_db,
extra_saturation_margin_db_);
}
}
bool AdaptiveModeLevelEstimator::IsConfident() const {
if (adjacent_speech_frames_threshold_ == 1) {
// Ignore `reliable_state_` when a single frame is enough to update the
// level estimate (because it is not used).
return preliminary_state_.time_to_full_buffer_ms == 0;
}
// Once confident, it remains confident.
RTC_DCHECK(reliable_state_.time_to_full_buffer_ms != 0 ||
preliminary_state_.time_to_full_buffer_ms == 0);
// During the first long enough speech sequence, `reliable_state_` must be
// ignored since `preliminary_state_` is used.
return reliable_state_.time_to_full_buffer_ms == 0 ||
(num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_ &&
preliminary_state_.time_to_full_buffer_ms == 0);
}
void AdaptiveModeLevelEstimator::Reset() {
ResetLevelEstimatorState(preliminary_state_);
ResetLevelEstimatorState(reliable_state_);
level_dbfs_ = ComputeLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs,
initial_saturation_margin_db_,
extra_saturation_margin_db_);
num_adjacent_speech_frames_ = 0;
}
void AdaptiveModeLevelEstimator::ResetLevelEstimatorState(
LevelEstimatorState& state) const {
state.time_to_full_buffer_ms = kFullBufferSizeMs;
state.level_dbfs.numerator = 0.f;
state.level_dbfs.denominator = 0.f;
ResetSaturationProtectorState(initial_saturation_margin_db_,
state.saturation_protector);
}
void AdaptiveModeLevelEstimator::DumpDebugData() const {
apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs", level_dbfs_);
apm_data_dumper_->DumpRaw("agc2_adaptive_num_adjacent_speech_frames_",
num_adjacent_speech_frames_);
apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_num",
preliminary_state_.level_dbfs.numerator);
apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_den",
preliminary_state_.level_dbfs.denominator);
apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_saturation_margin_db",
preliminary_state_.saturation_protector.margin_db);
}
} // namespace webrtc

modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.cc

@@ -1,65 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.h"
#include <cmath>
#include <vector>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
AdaptiveModeLevelEstimatorAgc::AdaptiveModeLevelEstimatorAgc(
ApmDataDumper* apm_data_dumper)
: level_estimator_(apm_data_dumper) {
set_target_level_dbfs(kDefaultAgc2LevelHeadroomDbfs);
}
// |audio| must be mono; in a multi-channel stream, provide the first (usually
// left) channel.
void AdaptiveModeLevelEstimatorAgc::Process(const int16_t* audio,
size_t length,
int sample_rate_hz) {
std::vector<float> float_audio_frame(audio, audio + length);
const float* const first_channel = &float_audio_frame[0];
AudioFrameView<const float> frame_view(&first_channel, 1 /* num channels */,
length);
const auto vad_prob = agc2_vad_.AnalyzeFrame(frame_view);
latest_voice_probability_ = vad_prob.speech_probability;
if (latest_voice_probability_ > kVadConfidenceThreshold) {
time_in_ms_since_last_estimate_ += kFrameDurationMs;
}
level_estimator_.Update(vad_prob);
}
// Retrieves the difference between the target RMS level and the current
// signal RMS level in dB. Returns true if an update is available and false
// otherwise, in which case |error| should be ignored and no action taken.
bool AdaptiveModeLevelEstimatorAgc::GetRmsErrorDb(int* error) {
if (time_in_ms_since_last_estimate_ <= kTimeUntilConfidentMs) {
return false;
}
*error =
std::floor(target_level_dbfs() - level_estimator_.level_dbfs() + 0.5f);
time_in_ms_since_last_estimate_ = 0;
return true;
}
void AdaptiveModeLevelEstimatorAgc::Reset() {
level_estimator_.Reset();
}
float AdaptiveModeLevelEstimatorAgc::voice_probability() const {
return latest_voice_probability_;
}
} // namespace webrtc

modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.h

@@ -1,51 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_
#include <stddef.h>
#include <stdint.h>
#include "modules/audio_processing/agc/agc.h"
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
namespace webrtc {
class AdaptiveModeLevelEstimatorAgc : public Agc {
public:
explicit AdaptiveModeLevelEstimatorAgc(ApmDataDumper* apm_data_dumper);
// |audio| must be mono; in a multi-channel stream, provide the first (usually
// left) channel.
void Process(const int16_t* audio,
size_t length,
int sample_rate_hz) override;
// Retrieves the difference between the target RMS level and the current
// signal RMS level in dB. Returns true if an update is available and false
// otherwise, in which case |error| should be ignored and no action taken.
bool GetRmsErrorDb(int* error) override;
void Reset() override;
float voice_probability() const override;
private:
static constexpr int kTimeUntilConfidentMs = 700;
static constexpr int kDefaultAgc2LevelHeadroomDbfs = -1;
int32_t time_in_ms_since_last_estimate_ = 0;
AdaptiveModeLevelEstimator level_estimator_;
VadLevelAnalyzer agc2_vad_;
float latest_voice_probability_ = 0.f;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_

modules/audio_processing/agc2/agc2_common.h

@@ -11,74 +11,50 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
#include <stddef.h>
namespace webrtc {
constexpr float kMinFloatS16Value = -32768.f;
constexpr float kMaxFloatS16Value = 32767.f;
constexpr float kMinFloatS16Value = -32768.0f;
constexpr float kMaxFloatS16Value = 32767.0f;
constexpr float kMaxAbsFloatS16Value = 32768.0f;
constexpr size_t kFrameDurationMs = 10;
constexpr size_t kSubFramesInFrame = 20;
constexpr size_t kMaximalNumberOfSamplesPerChannel = 480;
// Minimum audio level in dBFS scale for S16 samples.
constexpr float kMinLevelDbfs = -90.31f;
constexpr float kAttackFilterConstant = 0.f;
constexpr int kFrameDurationMs = 10;
constexpr int kSubFramesInFrame = 20;
constexpr int kMaximalNumberOfSamplesPerChannel = 480;
// Adaptive digital gain applier settings.
// Adaptive digital gain applier settings below.
constexpr float kHeadroomDbfs = 1.f;
constexpr float kMaxGainDb = 30.f;
constexpr float kInitialAdaptiveDigitalGainDb = 8.f;
// At what limiter levels should we start decreasing the adaptive digital gain.
constexpr float kLimiterThresholdForAgcGainDbfs = -kHeadroomDbfs;
constexpr float kLimiterThresholdForAgcGainDbfs = -1.0f;
// This is the threshold for speech. Speech frames are used for updating the
// speech level, measuring the amount of speech, and decide when to allow target
// gain reduction.
constexpr float kVadConfidenceThreshold = 0.9f;
// Number of milliseconds to wait to periodically reset the VAD.
constexpr int kVadResetPeriodMs = 1500;
// The amount of 'memory' of the Level Estimator. Decides leak factors.
constexpr size_t kFullBufferSizeMs = 1200;
constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs;
// Speech probability threshold to detect speech activity.
constexpr float kVadConfidenceThreshold = 0.95f;
constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
// Minimum number of adjacent speech frames having a sufficiently high speech
// probability to reliably detect speech activity.
constexpr int kAdjacentSpeechFramesThreshold = 12;
// Robust VAD probability and speech decisions.
constexpr float kDefaultSmoothedVadProbabilityAttack = 1.f;
constexpr int kDefaultLevelEstimatorAdjacentSpeechFramesThreshold = 1;
// Number of milliseconds of speech frames to observe to make the estimator
// confident.
constexpr float kLevelEstimatorTimeToConfidenceMs = 400;
constexpr float kLevelEstimatorLeakFactor =
1.0f - 1.0f / kLevelEstimatorTimeToConfidenceMs;
// Saturation Protector settings.
constexpr float kDefaultInitialSaturationMarginDb = 20.f;
constexpr float kDefaultExtraSaturationMarginDb = 2.f;
constexpr size_t kPeakEnveloperSuperFrameLengthMs = 400;
static_assert(kFullBufferSizeMs % kPeakEnveloperSuperFrameLengthMs == 0,
"Full buffer size should be a multiple of super frame length for "
"optimal Saturation Protector performance.");
constexpr size_t kPeakEnveloperBufferSize =
kFullBufferSizeMs / kPeakEnveloperSuperFrameLengthMs + 1;
// This value is 10 ** (-1/20 * frame_size_ms / satproc_attack_ms),
// where satproc_attack_ms is 5000.
constexpr float kSaturationProtectorAttackConstant = 0.9988493699365052f;
// This value is 10 ** (-1/20 * frame_size_ms / satproc_decay_ms),
// where satproc_decay_ms is 1000.
constexpr float kSaturationProtectorDecayConstant = 0.9997697679981565f;
// This is computed from kDecayMs by
// 10 ** (-1/20 * subframe_duration / kDecayMs).
// |subframe_duration| is |kFrameDurationMs / kSubFramesInFrame|.
// kDecayMs is defined in agc2_testing_common.h
constexpr float kDecayFilterConstant = 0.9998848773724686f;
constexpr float kSaturationProtectorInitialHeadroomDb = 20.0f;
constexpr int kSaturationProtectorBufferSize = 4;
// Number of interpolation points for each region of the limiter.
// These values have been tuned to limit the interpolated gain curve error given
// the limiter parameters and allowing a maximum error of +/- 32768^-1.
constexpr size_t kInterpolatedGainCurveKneePoints = 22;
constexpr size_t kInterpolatedGainCurveBeyondKneePoints = 10;
constexpr size_t kInterpolatedGainCurveTotalPoints =
constexpr int kInterpolatedGainCurveKneePoints = 22;
constexpr int kInterpolatedGainCurveBeyondKneePoints = 10;
constexpr int kInterpolatedGainCurveTotalPoints =
kInterpolatedGainCurveKneePoints + kInterpolatedGainCurveBeyondKneePoints;
} // namespace webrtc

modules/audio_processing/agc2/agc2_testing_common.cc

@@ -10,24 +10,85 @@
#include "modules/audio_processing/agc2/agc2_testing_common.h"
#include <math.h>
#include "rtc_base/checks.h"
namespace webrtc {
namespace test {
std::vector<double> LinSpace(const double l,
const double r,
size_t num_points) {
RTC_CHECK(num_points >= 2);
std::vector<double> LinSpace(double l, double r, int num_points) {
RTC_CHECK_GE(num_points, 2);
std::vector<double> points(num_points);
const double step = (r - l) / (num_points - 1.0);
points[0] = l;
for (size_t i = 1; i < num_points - 1; i++) {
for (int i = 1; i < num_points - 1; i++) {
points[i] = static_cast<double>(l) + i * step;
}
points[num_points - 1] = r;
return points;
}
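// For example (a sketch of the expected behavior, not part of the upstream
// file): LinSpace(0.0, 1.0, 5) returns {0.0, 0.25, 0.5, 0.75, 1.0} -- the
// step is (r - l) / (num_points - 1) and both endpoints are included.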
WhiteNoiseGenerator::WhiteNoiseGenerator(int min_amplitude, int max_amplitude)
: rand_gen_(42),
min_amplitude_(min_amplitude),
max_amplitude_(max_amplitude) {
RTC_DCHECK_LT(min_amplitude_, max_amplitude_);
RTC_DCHECK_LE(kMinS16, min_amplitude_);
RTC_DCHECK_LE(min_amplitude_, kMaxS16);
RTC_DCHECK_LE(kMinS16, max_amplitude_);
RTC_DCHECK_LE(max_amplitude_, kMaxS16);
}
float WhiteNoiseGenerator::operator()() {
return static_cast<float>(rand_gen_.Rand(min_amplitude_, max_amplitude_));
}
SineGenerator::SineGenerator(float amplitude,
float frequency_hz,
int sample_rate_hz)
: amplitude_(amplitude),
frequency_hz_(frequency_hz),
sample_rate_hz_(sample_rate_hz),
x_radians_(0.0f) {
RTC_DCHECK_GT(amplitude_, 0);
RTC_DCHECK_LE(amplitude_, kMaxS16);
}
float SineGenerator::operator()() {
constexpr float kPi = 3.1415926536f;
x_radians_ += frequency_hz_ / sample_rate_hz_ * 2 * kPi;
if (x_radians_ >= 2 * kPi) {
x_radians_ -= 2 * kPi;
}
// Use sinf instead of std::sinf for libstdc++ compatibility.
return amplitude_ * sinf(x_radians_);
}
PulseGenerator::PulseGenerator(float pulse_amplitude,
float no_pulse_amplitude,
float frequency_hz,
int sample_rate_hz)
: pulse_amplitude_(pulse_amplitude),
no_pulse_amplitude_(no_pulse_amplitude),
samples_period_(
static_cast<int>(static_cast<float>(sample_rate_hz) / frequency_hz)),
sample_counter_(0) {
RTC_DCHECK_GE(pulse_amplitude_, kMinS16);
RTC_DCHECK_LE(pulse_amplitude_, kMaxS16);
RTC_DCHECK_GT(no_pulse_amplitude_, kMinS16);
RTC_DCHECK_LE(no_pulse_amplitude_, kMaxS16);
RTC_DCHECK_GT(sample_rate_hz, frequency_hz);
}
float PulseGenerator::operator()() {
sample_counter_++;
if (sample_counter_ >= samples_period_) {
sample_counter_ -= samples_period_;
}
return static_cast<float>(sample_counter_ == 0 ? pulse_amplitude_
: no_pulse_amplitude_);
}
} // namespace test
} // namespace webrtc
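The three generators are function objects: each call to operator() produces the next sample, so a test can fill a frame by repeated invocation. A minimal sketch (not part of this commit; the buffer size and tone parameters are arbitrary choices for illustration):

#include <vector>

#include "modules/audio_processing/agc2/agc2_testing_common.h"

// Fills a 10 ms mono buffer at 48 kHz with a 440 Hz tone of amplitude 1000.
std::vector<float> MakeTestTone() {
  webrtc::test::SineGenerator generator(/*amplitude=*/1000.0f,
                                        /*frequency_hz=*/440.0f,
                                        /*sample_rate_hz=*/48000);
  std::vector<float> samples(480);  // 480 samples = 10 ms at 48 kHz.
  for (float& sample : samples) {
    sample = generator();
  }
  return samples;
}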

modules/audio_processing/agc2/agc2_testing_common.h

@@ -11,65 +11,69 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
#include <math.h>
#include <limits>
#include <vector>
#include "rtc_base/checks.h"
#include "rtc_base/random.h"
namespace webrtc {
namespace test {
constexpr float kMinS16 =
static_cast<float>(std::numeric_limits<int16_t>::min());
constexpr float kMaxS16 =
static_cast<float>(std::numeric_limits<int16_t>::max());
// Level Estimator test parameters.
constexpr float kDecayMs = 500.f;
constexpr float kDecayMs = 20.0f;
// Limiter parameters.
constexpr float kLimiterMaxInputLevelDbFs = 1.f;
constexpr float kLimiterKneeSmoothnessDb = 1.f;
constexpr float kLimiterCompressionRatio = 5.f;
constexpr float kPi = 3.1415926536f;
std::vector<double> LinSpace(const double l, const double r, size_t num_points);
// Returns evenly spaced `num_points` numbers over a specified interval [l, r].
std::vector<double> LinSpace(double l, double r, int num_points);
class SineGenerator {
// Generates white noise.
class WhiteNoiseGenerator {
public:
SineGenerator(float frequency, int rate)
: frequency_(frequency), rate_(rate) {}
float operator()() {
x_radians_ += frequency_ / rate_ * 2 * kPi;
if (x_radians_ > 2 * kPi) {
x_radians_ -= 2 * kPi;
}
return 1000.f * sinf(x_radians_);
}
WhiteNoiseGenerator(int min_amplitude, int max_amplitude);
float operator()();
private:
float frequency_;
int rate_;
float x_radians_ = 0.f;
Random rand_gen_;
const int min_amplitude_;
const int max_amplitude_;
};
class PulseGenerator {
// Generates a sine function.
class SineGenerator {
public:
PulseGenerator(float frequency, int rate)
: samples_period_(
static_cast<int>(static_cast<float>(rate) / frequency)) {
RTC_DCHECK_GT(rate, frequency);
}
float operator()() {
sample_counter_++;
if (sample_counter_ >= samples_period_) {
sample_counter_ -= samples_period_;
}
return static_cast<float>(
sample_counter_ == 0 ? std::numeric_limits<int16_t>::max() : 10.f);
}
SineGenerator(float amplitude, float frequency_hz, int sample_rate_hz);
float operator()();
private:
int samples_period_;
int sample_counter_ = 0;
const float amplitude_;
const float frequency_hz_;
const int sample_rate_hz_;
float x_radians_;
};
// Generates periodic pulses.
class PulseGenerator {
public:
PulseGenerator(float pulse_amplitude,
float no_pulse_amplitude,
float frequency_hz,
int sample_rate_hz);
float operator()();
private:
const float pulse_amplitude_;
const float no_pulse_amplitude_;
const int samples_period_;
int sample_counter_;
};
} // namespace test

View File

@ -10,27 +10,51 @@
#include "modules/audio_processing/agc2/biquad_filter.h"
#include <stddef.h>
#include "rtc_base/arraysize.h"
namespace webrtc {
// Transposed direct form I implementation of a bi-quad filter applied to an
// input signal |x| to produce an output signal |y|.
BiQuadFilter::BiQuadFilter(const Config& config)
: config_(config), state_({}) {}
BiQuadFilter::~BiQuadFilter() = default;
void BiQuadFilter::SetConfig(const Config& config) {
config_ = config;
state_ = {};
}
void BiQuadFilter::Reset() {
state_ = {};
}
void BiQuadFilter::Process(rtc::ArrayView<const float> x,
rtc::ArrayView<float> y) {
for (size_t k = 0; k < x.size(); ++k) {
// Use temporary variable for x[k] to allow in-place function call
// (that x and y refer to the same array).
RTC_DCHECK_EQ(x.size(), y.size());
const float config_a0 = config_.a[0];
const float config_a1 = config_.a[1];
const float config_b0 = config_.b[0];
const float config_b1 = config_.b[1];
const float config_b2 = config_.b[2];
float state_a0 = state_.a[0];
float state_a1 = state_.a[1];
float state_b0 = state_.b[0];
float state_b1 = state_.b[1];
for (size_t k = 0, x_size = x.size(); k < x_size; ++k) {
// Use a temporary variable for `x[k]` to allow in-place processing.
const float tmp = x[k];
y[k] = coefficients_.b[0] * tmp + coefficients_.b[1] * biquad_state_.b[0] +
coefficients_.b[2] * biquad_state_.b[1] -
coefficients_.a[0] * biquad_state_.a[0] -
coefficients_.a[1] * biquad_state_.a[1];
biquad_state_.b[1] = biquad_state_.b[0];
biquad_state_.b[0] = tmp;
biquad_state_.a[1] = biquad_state_.a[0];
biquad_state_.a[0] = y[k];
float y_k = config_b0 * tmp + config_b1 * state_b0 + config_b2 * state_b1 -
config_a0 * state_a0 - config_a1 * state_a1;
state_b1 = state_b0;
state_b0 = tmp;
state_a1 = state_a0;
state_a0 = y_k;
y[k] = y_k;
}
state_.a[0] = state_a0;
state_.a[1] = state_a1;
state_.b[0] = state_b0;
state_.b[1] = state_b1;
}
} // namespace webrtc

View File

@ -11,54 +11,44 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
#include <algorithm>
#include "api/array_view.h"
#include "rtc_base/arraysize.h"
#include "rtc_base/constructor_magic.h"
namespace webrtc {
// Transposed direct form I implementation of a bi-quad filter.
// b[0] + b[1] • z^(-1) + b[2] • z^(-2)
// H(z) = ------------------------------------
// 1 + a[1] • z^(-1) + a[2] • z^(-2)
class BiQuadFilter {
public:
// Normalized filter coefficients.
// b_0 + b_1 • z^(-1) + b_2 • z^(-2)
// H(z) = ---------------------------------
// 1 + a_1 • z^(-1) + a_2 • z^(-2)
struct BiQuadCoefficients {
float b[3];
float a[2];
// Computed as `[b, a] = scipy.signal.butter(N=2, Wn, btype)`.
struct Config {
float b[3]; // b[0], b[1], b[2].
float a[2]; // a[1], a[2].
};
BiQuadFilter() = default;
explicit BiQuadFilter(const Config& config);
BiQuadFilter(const BiQuadFilter&) = delete;
BiQuadFilter& operator=(const BiQuadFilter&) = delete;
~BiQuadFilter();
void Initialize(const BiQuadCoefficients& coefficients) {
coefficients_ = coefficients;
}
// Sets the filter configuration and resets the internal state.
void SetConfig(const Config& config);
void Reset() { biquad_state_.Reset(); }
// Zeroes the filter state.
void Reset();
// Produces a filtered output y of the input x. Both x and y need to
// have the same length. In-place modification is allowed.
// Filters `x` and writes the output in `y`, which must have the same length
// of `x`. In-place processing is supported.
void Process(rtc::ArrayView<const float> x, rtc::ArrayView<float> y);
private:
struct BiQuadState {
BiQuadState() { Reset(); }
void Reset() {
std::fill(b, b + arraysize(b), 0.f);
std::fill(a, a + arraysize(a), 0.f);
}
Config config_;
struct State {
float b[2];
float a[2];
};
BiQuadState biquad_state_;
BiQuadCoefficients coefficients_;
RTC_DISALLOW_COPY_AND_ASSIGN(BiQuadFilter);
} state_;
};
} // namespace webrtc
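A minimal sketch of the new Config-based API (not part of the commit; the function name is hypothetical), reusing the 16 kHz Butterworth coefficients quoted in the deleted down_sampler.cc later in this commit:

#include "api/array_view.h"
#include "modules/audio_processing/agc2/biquad_filter.h"

// Low-passes a 16 kHz frame in place; [b, a] = butter(2, (41/64*4000)/8000).
void BandLimitTo4kHz(rtc::ArrayView<float> frame_16khz) {
  webrtc::BiQuadFilter::Config config;
  config.b[0] = 0.1455f;
  config.b[1] = 0.2911f;
  config.b[2] = 0.1455f;
  config.a[0] = -0.6698f;
  config.a[1] = 0.2520f;
  webrtc::BiQuadFilter filter(config);
  filter.Process(frame_16khz, frame_16khz);  // In-place processing is allowed.
}

In streaming use the filter would be a long-lived member rather than a local, so that its two-sample state carries across consecutive frames.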

View File

@ -0,0 +1,384 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/clipping_predictor.h"
#include <algorithm>
#include <memory>
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/clipping_predictor_level_buffer.h"
#include "modules/audio_processing/agc2/gain_map_internal.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
constexpr int kClippingPredictorMaxGainChange = 15;
// Returns an input volume in the [`min_input_volume`, `max_input_volume`] range
// that reduces `gain_error_db`, which is a gain error estimated when
// `input_volume` was applied, according to a fixed gain map.
int ComputeVolumeUpdate(int gain_error_db,
int input_volume,
int min_input_volume,
int max_input_volume) {
RTC_DCHECK_GE(input_volume, 0);
RTC_DCHECK_LE(input_volume, max_input_volume);
if (gain_error_db == 0) {
return input_volume;
}
int new_volume = input_volume;
if (gain_error_db > 0) {
while (kGainMap[new_volume] - kGainMap[input_volume] < gain_error_db &&
new_volume < max_input_volume) {
++new_volume;
}
} else {
while (kGainMap[new_volume] - kGainMap[input_volume] > gain_error_db &&
new_volume > min_input_volume) {
--new_volume;
}
}
return new_volume;
}
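// ComputeVolumeUpdate example: if the gain map rises by roughly 1 dB per
// volume step near `input_volume`, a `gain_error_db` of +3 walks the volume up
// by about 3 steps, while -3 walks it down symmetrically; the result always
// stays within [`min_input_volume`, `max_input_volume`].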
float ComputeCrestFactor(const ClippingPredictorLevelBuffer::Level& level) {
const float crest_factor =
FloatS16ToDbfs(level.max) - FloatS16ToDbfs(std::sqrt(level.average));
return crest_factor;
}
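// Note on ComputeCrestFactor: since FloatS16ToDbfs(v) is essentially
// 20 * log10(v / 32768) (see audio_util.h), the normalization cancels and the
// crest factor reduces to 20 * log10(peak / rms), i.e. the peak-to-RMS ratio
// of the window in dB.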
// Crest factor-based clipping prediction and clipped level step estimation.
class ClippingEventPredictor : public ClippingPredictor {
public:
// ClippingEventPredictor with `num_channels` channels (limited to values
// higher than zero); window size `window_length` and reference window size
// `reference_window_length` (both referring to the number of frames in the
// respective sliding windows and limited to values higher than zero);
// reference window delay `reference_window_delay` (delay in frames, limited
// to values zero and higher with an additional requirement of
`window_length` < `reference_window_length` + `reference_window_delay`);
// and an estimation peak threshold `clipping_threshold` and a crest factor
// drop threshold `crest_factor_margin` (both in dB).
ClippingEventPredictor(int num_channels,
int window_length,
int reference_window_length,
int reference_window_delay,
float clipping_threshold,
float crest_factor_margin)
: window_length_(window_length),
reference_window_length_(reference_window_length),
reference_window_delay_(reference_window_delay),
clipping_threshold_(clipping_threshold),
crest_factor_margin_(crest_factor_margin) {
RTC_DCHECK_GT(num_channels, 0);
RTC_DCHECK_GT(window_length, 0);
RTC_DCHECK_GT(reference_window_length, 0);
RTC_DCHECK_GE(reference_window_delay, 0);
RTC_DCHECK_GT(reference_window_length + reference_window_delay,
window_length);
const int buffer_length = GetMinFramesProcessed();
RTC_DCHECK_GT(buffer_length, 0);
for (int i = 0; i < num_channels; ++i) {
ch_buffers_.push_back(
std::make_unique<ClippingPredictorLevelBuffer>(buffer_length));
}
}
ClippingEventPredictor(const ClippingEventPredictor&) = delete;
ClippingEventPredictor& operator=(const ClippingEventPredictor&) = delete;
~ClippingEventPredictor() {}
void Reset() {
const int num_channels = ch_buffers_.size();
for (int i = 0; i < num_channels; ++i) {
ch_buffers_[i]->Reset();
}
}
// Analyzes a frame of audio and stores the framewise metrics in
// `ch_buffers_`.
void Analyze(const AudioFrameView<const float>& frame) {
const int num_channels = frame.num_channels();
RTC_DCHECK_EQ(num_channels, ch_buffers_.size());
const int samples_per_channel = frame.samples_per_channel();
RTC_DCHECK_GT(samples_per_channel, 0);
for (int channel = 0; channel < num_channels; ++channel) {
float sum_squares = 0.0f;
float peak = 0.0f;
for (const auto& sample : frame.channel(channel)) {
sum_squares += sample * sample;
peak = std::max(std::fabs(sample), peak);
}
ch_buffers_[channel]->Push(
{sum_squares / static_cast<float>(samples_per_channel), peak});
}
}
// Estimates the analog gain adjustment for channel `channel` using a
// sliding window over the frame-wise metrics in `ch_buffers_`. Returns an
// estimate for the clipped level step equal to `default_clipped_level_step_`
// if at least `GetMinFramesProcessed()` frames have been processed since the
// last reset and a clipping event is predicted. `level`, `min_mic_level`, and
// `max_mic_level` are limited to [0, 255] and `default_step` to [1, 255].
absl::optional<int> EstimateClippedLevelStep(int channel,
int level,
int default_step,
int min_mic_level,
int max_mic_level) const {
RTC_CHECK_GE(channel, 0);
RTC_CHECK_LT(channel, ch_buffers_.size());
RTC_DCHECK_GE(level, 0);
RTC_DCHECK_LE(level, 255);
RTC_DCHECK_GT(default_step, 0);
RTC_DCHECK_LE(default_step, 255);
RTC_DCHECK_GE(min_mic_level, 0);
RTC_DCHECK_LE(min_mic_level, 255);
RTC_DCHECK_GE(max_mic_level, 0);
RTC_DCHECK_LE(max_mic_level, 255);
if (level <= min_mic_level) {
return absl::nullopt;
}
if (PredictClippingEvent(channel)) {
const int new_level =
rtc::SafeClamp(level - default_step, min_mic_level, max_mic_level);
const int step = level - new_level;
if (step > 0) {
return step;
}
}
return absl::nullopt;
}
private:
int GetMinFramesProcessed() const {
return reference_window_delay_ + reference_window_length_;
}
// Predicts clipping events based on the processed audio frames. Returns
// true if a clipping event is likely.
bool PredictClippingEvent(int channel) const {
const auto metrics =
ch_buffers_[channel]->ComputePartialMetrics(0, window_length_);
if (!metrics.has_value() ||
!(FloatS16ToDbfs(metrics.value().max) > clipping_threshold_)) {
return false;
}
const auto reference_metrics = ch_buffers_[channel]->ComputePartialMetrics(
reference_window_delay_, reference_window_length_);
if (!reference_metrics.has_value()) {
return false;
}
const float crest_factor = ComputeCrestFactor(metrics.value());
const float reference_crest_factor =
ComputeCrestFactor(reference_metrics.value());
if (crest_factor < reference_crest_factor - crest_factor_margin_) {
return true;
}
return false;
}
std::vector<std::unique_ptr<ClippingPredictorLevelBuffer>> ch_buffers_;
const int window_length_;
const int reference_window_length_;
const int reference_window_delay_;
const float clipping_threshold_;
const float crest_factor_margin_;
};
// Performs crest factor-based clipping peak prediction.
class ClippingPeakPredictor : public ClippingPredictor {
public:
// Ctor. ClippingPeakPredictor with `num_channels` channels (limited to values
// higher than zero); window size `window_length` and reference window size
// `reference_window_length` (both referring to the number of frames in the
// respective sliding windows and limited to values higher than zero);
// reference window delay `reference_window_delay` (delay in frames, limited
// to values zero and higher with an additional requirement of
`window_length` < `reference_window_length` + `reference_window_delay`);
// and a clipping prediction threshold `clipping_threshold` (in dB). Adaptive
// clipped level step estimation is used if `adaptive_step_estimation` is
// true.
explicit ClippingPeakPredictor(int num_channels,
int window_length,
int reference_window_length,
int reference_window_delay,
int clipping_threshold,
bool adaptive_step_estimation)
: window_length_(window_length),
reference_window_length_(reference_window_length),
reference_window_delay_(reference_window_delay),
clipping_threshold_(clipping_threshold),
adaptive_step_estimation_(adaptive_step_estimation) {
RTC_DCHECK_GT(num_channels, 0);
RTC_DCHECK_GT(window_length, 0);
RTC_DCHECK_GT(reference_window_length, 0);
RTC_DCHECK_GE(reference_window_delay, 0);
RTC_DCHECK_GT(reference_window_length + reference_window_delay,
window_length);
const int buffer_length = GetMinFramesProcessed();
RTC_DCHECK_GT(buffer_length, 0);
for (int i = 0; i < num_channels; ++i) {
ch_buffers_.push_back(
std::make_unique<ClippingPredictorLevelBuffer>(buffer_length));
}
}
ClippingPeakPredictor(const ClippingPeakPredictor&) = delete;
ClippingPeakPredictor& operator=(const ClippingPeakPredictor&) = delete;
~ClippingPeakPredictor() {}
void Reset() {
const int num_channels = ch_buffers_.size();
for (int i = 0; i < num_channels; ++i) {
ch_buffers_[i]->Reset();
}
}
// Analyzes a frame of audio and stores the framewise metrics in
// `ch_buffers_`.
void Analyze(const AudioFrameView<const float>& frame) {
const int num_channels = frame.num_channels();
RTC_DCHECK_EQ(num_channels, ch_buffers_.size());
const int samples_per_channel = frame.samples_per_channel();
RTC_DCHECK_GT(samples_per_channel, 0);
for (int channel = 0; channel < num_channels; ++channel) {
float sum_squares = 0.0f;
float peak = 0.0f;
for (const auto& sample : frame.channel(channel)) {
sum_squares += sample * sample;
peak = std::max(std::fabs(sample), peak);
}
ch_buffers_[channel]->Push(
{sum_squares / static_cast<float>(samples_per_channel), peak});
}
}
// Estimates the analog gain adjustment for channel `channel` using a
// sliding window over the frame-wise metrics in `ch_buffers_`. Returns an
// estimate for the clipped level step (equal to
// `default_clipped_level_step_` if `adaptive_estimation_` is false) if at
// least `GetMinFramesProcessed()` frames have been processed since the last
// reset and a clipping event is predicted. `level`, `min_mic_level`, and
// `max_mic_level` are limited to [0, 255] and `default_step` to [1, 255].
absl::optional<int> EstimateClippedLevelStep(int channel,
int level,
int default_step,
int min_mic_level,
int max_mic_level) const {
RTC_DCHECK_GE(channel, 0);
RTC_DCHECK_LT(channel, ch_buffers_.size());
RTC_DCHECK_GE(level, 0);
RTC_DCHECK_LE(level, 255);
RTC_DCHECK_GT(default_step, 0);
RTC_DCHECK_LE(default_step, 255);
RTC_DCHECK_GE(min_mic_level, 0);
RTC_DCHECK_LE(min_mic_level, 255);
RTC_DCHECK_GE(max_mic_level, 0);
RTC_DCHECK_LE(max_mic_level, 255);
if (level <= min_mic_level) {
return absl::nullopt;
}
absl::optional<float> estimate_db = EstimatePeakValue(channel);
if (estimate_db.has_value() && estimate_db.value() > clipping_threshold_) {
int step = 0;
if (!adaptive_step_estimation_) {
step = default_step;
} else {
const int estimated_gain_change =
rtc::SafeClamp(-static_cast<int>(std::ceil(estimate_db.value())),
-kClippingPredictorMaxGainChange, 0);
step =
std::max(level - ComputeVolumeUpdate(estimated_gain_change, level,
min_mic_level, max_mic_level),
default_step);
}
const int new_level =
rtc::SafeClamp(level - step, min_mic_level, max_mic_level);
if (level > new_level) {
return level - new_level;
}
}
return absl::nullopt;
}
private:
int GetMinFramesProcessed() {
return reference_window_delay_ + reference_window_length_;
}
// Predicts clipping sample peaks based on the processed audio frames.
// Returns the estimated peak value if clipping is predicted. Otherwise
// returns absl::nullopt.
absl::optional<float> EstimatePeakValue(int channel) const {
const auto reference_metrics = ch_buffers_[channel]->ComputePartialMetrics(
reference_window_delay_, reference_window_length_);
if (!reference_metrics.has_value()) {
return absl::nullopt;
}
const auto metrics =
ch_buffers_[channel]->ComputePartialMetrics(0, window_length_);
if (!metrics.has_value() ||
!(FloatS16ToDbfs(metrics.value().max) > clipping_threshold_)) {
return absl::nullopt;
}
const float reference_crest_factor =
ComputeCrestFactor(reference_metrics.value());
const float& mean_squares = metrics.value().average;
const float projected_peak =
reference_crest_factor + FloatS16ToDbfs(std::sqrt(mean_squares));
return projected_peak;
}
std::vector<std::unique_ptr<ClippingPredictorLevelBuffer>> ch_buffers_;
const int window_length_;
const int reference_window_length_;
const int reference_window_delay_;
const int clipping_threshold_;
const bool adaptive_step_estimation_;
};
} // namespace
std::unique_ptr<ClippingPredictor> CreateClippingPredictor(
int num_channels,
const AudioProcessing::Config::GainController1::AnalogGainController::
ClippingPredictor& config) {
if (!config.enabled) {
RTC_LOG(LS_INFO) << "[AGC2] Clipping prediction disabled.";
return nullptr;
}
RTC_LOG(LS_INFO) << "[AGC2] Clipping prediction enabled.";
using ClippingPredictorMode = AudioProcessing::Config::GainController1::
AnalogGainController::ClippingPredictor::Mode;
switch (config.mode) {
case ClippingPredictorMode::kClippingEventPrediction:
return std::make_unique<ClippingEventPredictor>(
num_channels, config.window_length, config.reference_window_length,
config.reference_window_delay, config.clipping_threshold,
config.crest_factor_margin);
case ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction:
return std::make_unique<ClippingPeakPredictor>(
num_channels, config.window_length, config.reference_window_length,
config.reference_window_delay, config.clipping_threshold,
/*adaptive_step_estimation=*/true);
case ClippingPredictorMode::kFixedStepClippingPeakPrediction:
return std::make_unique<ClippingPeakPredictor>(
num_channels, config.window_length, config.reference_window_length,
config.reference_window_delay, config.clipping_threshold,
/*adaptive_step_estimation=*/false);
}
RTC_DCHECK_NOTREACHED();
}
} // namespace webrtc

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/include/audio_processing.h"
namespace webrtc {
// Frame-wise clipping prediction and clipped level step estimation. Analyzes
// 10 ms multi-channel frames and estimates an analog mic level decrease step
// to possibly avoid clipping when predicted. `Analyze()` and
// `EstimateClippedLevelStep()` can be called in any order.
class ClippingPredictor {
public:
virtual ~ClippingPredictor() = default;
virtual void Reset() = 0;
// Analyzes a 10 ms multi-channel audio frame.
virtual void Analyze(const AudioFrameView<const float>& frame) = 0;
// Predicts if clipping is going to occur for the specified `channel` in the
// near-future and, if so, it returns a recommended analog mic level decrease
// step. Returns absl::nullopt if clipping is not predicted.
// `level` is the current analog mic level, `default_step` is the amount the
// mic level is lowered by the analog controller with every clipping event and
// `min_mic_level` and `max_mic_level` is the range of allowed analog mic
// levels.
virtual absl::optional<int> EstimateClippedLevelStep(
int channel,
int level,
int default_step,
int min_mic_level,
int max_mic_level) const = 0;
};
// Creates a ClippingPredictor based on the provided `config`. When enabled,
// the following must hold for `config`:
// `window_length < reference_window_length + reference_window_delay`.
// Returns `nullptr` if `config.enabled` is false.
std::unique_ptr<ClippingPredictor> CreateClippingPredictor(
int num_channels,
const AudioProcessing::Config::GainController1::AnalogGainController::
ClippingPredictor& config);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_H_
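A minimal sketch of the factory API (not part of the commit; the helper name is hypothetical), enabling event-based prediction for a stereo stream while other config fields keep their defaults:

#include <memory>

#include "modules/audio_processing/agc2/clipping_predictor.h"

std::unique_ptr<webrtc::ClippingPredictor> MakeEventPredictor() {
  using ClippingPredictorConfig = webrtc::AudioProcessing::Config::
      GainController1::AnalogGainController::ClippingPredictor;
  ClippingPredictorConfig config;
  config.enabled = true;  // CreateClippingPredictor returns nullptr otherwise.
  config.mode = ClippingPredictorConfig::Mode::kClippingEventPrediction;
  return webrtc::CreateClippingPredictor(/*num_channels=*/2, config);
}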

View File

@ -0,0 +1,77 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/clipping_predictor_level_buffer.h"
#include <algorithm>
#include <cmath>
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
namespace webrtc {
bool ClippingPredictorLevelBuffer::Level::operator==(const Level& level) const {
constexpr float kEpsilon = 1e-6f;
return std::fabs(average - level.average) < kEpsilon &&
std::fabs(max - level.max) < kEpsilon;
}
ClippingPredictorLevelBuffer::ClippingPredictorLevelBuffer(int capacity)
: tail_(-1), size_(0), data_(std::max(1, capacity)) {
if (capacity > kMaxCapacity) {
RTC_LOG(LS_WARNING) << "[agc]: ClippingPredictorLevelBuffer exceeds the "
<< "maximum allowed capacity. Capacity: " << capacity;
}
RTC_DCHECK(!data_.empty());
}
void ClippingPredictorLevelBuffer::Reset() {
tail_ = -1;
size_ = 0;
}
void ClippingPredictorLevelBuffer::Push(Level level) {
++tail_;
if (tail_ == Capacity()) {
tail_ = 0;
}
if (size_ < Capacity()) {
size_++;
}
data_[tail_] = level;
}
// TODO(bugs.webrtc.org/12774): Optimize partial computation for long buffers.
absl::optional<ClippingPredictorLevelBuffer::Level>
ClippingPredictorLevelBuffer::ComputePartialMetrics(int delay,
int num_items) const {
RTC_DCHECK_GE(delay, 0);
RTC_DCHECK_LT(delay, Capacity());
RTC_DCHECK_GT(num_items, 0);
RTC_DCHECK_LE(num_items, Capacity());
RTC_DCHECK_LE(delay + num_items, Capacity());
if (delay + num_items > Size()) {
return absl::nullopt;
}
float sum = 0.0f;
float max = 0.0f;
for (int i = 0; i < num_items && i < Size(); ++i) {
int idx = tail_ - delay - i;
if (idx < 0) {
idx += Capacity();
}
sum += data_[idx].average;
max = std::fmax(data_[idx].max, max);
}
return absl::optional<Level>({sum / static_cast<float>(num_items), max});
}
} // namespace webrtc

View File

@ -0,0 +1,71 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
namespace webrtc {
// A circular buffer to store frame-wise `Level` items for clipping prediction.
// The current implementation is not optimized for large buffer lengths.
class ClippingPredictorLevelBuffer {
public:
struct Level {
float average;
float max;
bool operator==(const Level& level) const;
};
// Recommended maximum capacity. It is possible to create a buffer with a
// larger capacity, but the implementation is not optimized for large values.
static constexpr int kMaxCapacity = 100;
// Ctor. Sets the buffer capacity to max(1, `capacity`) and logs a warning
// message if the capacity is greater than `kMaxCapacity`.
explicit ClippingPredictorLevelBuffer(int capacity);
~ClippingPredictorLevelBuffer() {}
ClippingPredictorLevelBuffer(const ClippingPredictorLevelBuffer&) = delete;
ClippingPredictorLevelBuffer& operator=(const ClippingPredictorLevelBuffer&) =
delete;
void Reset();
// Returns the current number of items stored in the buffer.
int Size() const { return size_; }
// Returns the capacity of the buffer.
int Capacity() const { return data_.size(); }
// Adds a `level` item into the circular buffer `data_`. Stores at most
// `Capacity()` items. If more items are pushed, the new item replaces the
// least recently pushed item.
void Push(Level level);
// If at least `num_items` + `delay` items have been pushed, returns the
// average and maximum value for the `num_items` most recently pushed items
// from `delay` to `delay` - `num_items` (a delay equal to zero corresponds
// to the most recently pushed item). The value of `delay` is limited to
// [0, N] and `num_items` to [1, M] where N + M is the capacity of the buffer.
absl::optional<Level> ComputePartialMetrics(int delay, int num_items) const;
private:
int tail_;
int size_;
std::vector<Level> data_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_
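A minimal sketch of the buffer semantics (not part of the commit; values are arbitrary):

#include "modules/audio_processing/agc2/clipping_predictor_level_buffer.h"

void LevelBufferExample() {
  webrtc::ClippingPredictorLevelBuffer buffer(/*capacity=*/5);
  buffer.Push({/*average=*/0.04f, /*max=*/0.5f});
  buffer.Push({/*average=*/0.09f, /*max=*/0.8f});
  buffer.Push({/*average=*/0.16f, /*max=*/1.0f});
  // Mean of the two newest `average` values and their maximum peak:
  // {(0.16f + 0.09f) / 2, 1.0f}. With delay + num_items > Size() it would
  // return absl::nullopt instead.
  auto metrics = buffer.ComputePartialMetrics(/*delay=*/0, /*num_items=*/2);
}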

View File

@ -105,7 +105,7 @@ std::vector<double> SampleLimiterRegion(const LimiterDbGainCurve* limiter) {
const auto interval = q.top();
q.pop();
// Split |interval| and enqueue.
// Split `interval` and enqueue.
double x_split = (interval.x0 + interval.x1) / 2.0;
q.emplace(interval.x0, x_split,
LimiterUnderApproximationNegativeError(limiter, interval.x0,
@ -135,7 +135,7 @@ std::vector<double> SampleLimiterRegion(const LimiterDbGainCurve* limiter) {
void PrecomputeKneeApproxParams(const LimiterDbGainCurve* limiter,
test::InterpolatedParameters* parameters) {
static_assert(kInterpolatedGainCurveKneePoints > 2, "");
// Get |kInterpolatedGainCurveKneePoints| - 1 equally spaced points.
// Get `kInterpolatedGainCurveKneePoints` - 1 equally spaced points.
const std::vector<double> points = test::LinSpace(
limiter->knee_start_linear(), limiter->limiter_start_linear(),
kInterpolatedGainCurveKneePoints - 1);

View File

@ -29,8 +29,8 @@ namespace test {
// Knee and beyond-knee regions approximation parameters.
// The gain curve is approximated as a piece-wise linear function.
// |approx_params_x_| are the boundaries between adjacent linear pieces,
// |approx_params_m_| and |approx_params_q_| are the slope and the y-intercept
// `approx_params_x_` are the boundaries between adjacent linear pieces,
// `approx_params_m_` and `approx_params_q_` are the slope and the y-intercept
// values of each piece.
struct InterpolatedParameters {
std::array<float, kInterpolatedGainCurveTotalPoints>

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/strings/string_builder.h"
#include "rtc_base/system/arch.h"
#include "system_wrappers/include/cpu_features_wrapper.h"
namespace webrtc {
std::string AvailableCpuFeatures::ToString() const {
char buf[64];
rtc::SimpleStringBuilder builder(buf);
bool first = true;
if (sse2) {
builder << (first ? "SSE2" : "_SSE2");
first = false;
}
if (avx2) {
builder << (first ? "AVX2" : "_AVX2");
first = false;
}
if (neon) {
builder << (first ? "NEON" : "_NEON");
first = false;
}
if (first) {
return "none";
}
return builder.str();
}
// Detects available CPU features.
AvailableCpuFeatures GetAvailableCpuFeatures() {
#if defined(WEBRTC_ARCH_X86_FAMILY)
return {/*sse2=*/GetCPUInfo(kSSE2) != 0,
/*avx2=*/GetCPUInfo(kAVX2) != 0,
/*neon=*/false};
#elif defined(WEBRTC_HAS_NEON)
return {/*sse2=*/false,
/*avx2=*/false,
/*neon=*/true};
#else
return {/*sse2=*/false,
/*avx2=*/false,
/*neon=*/false};
#endif
}
AvailableCpuFeatures NoAvailableCpuFeatures() {
return {/*sse2=*/false, /*avx2=*/false, /*neon=*/false};
}
} // namespace webrtc

View File

@ -0,0 +1,39 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
#define MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
#include <string>
namespace webrtc {
// Collection of flags indicating which CPU features are available on the
// current platform. True means available.
struct AvailableCpuFeatures {
AvailableCpuFeatures(bool sse2, bool avx2, bool neon)
: sse2(sse2), avx2(avx2), neon(neon) {}
// Intel.
bool sse2;
bool avx2;
// ARM.
bool neon;
std::string ToString() const;
};
// Detects what CPU features are available.
AvailableCpuFeatures GetAvailableCpuFeatures();
// Returns the CPU feature flags all set to false.
AvailableCpuFeatures NoAvailableCpuFeatures();
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
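A minimal usage sketch (not part of the commit; the function name is hypothetical):

#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/logging.h"

// Logs the detected SIMD capabilities, e.g. "SSE2_AVX2", "NEON", or "none".
void LogAgc2CpuFeatures() {
  const webrtc::AvailableCpuFeatures features =
      webrtc::GetAvailableCpuFeatures();
  RTC_LOG(LS_INFO) << "[AGC2] CPU features: " << features.ToString();
}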

View File

@ -1,99 +0,0 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/down_sampler.h"
#include <string.h>
#include <algorithm>
#include "modules/audio_processing/agc2/biquad_filter.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr int kChunkSizeMs = 10;
constexpr int kSampleRate8kHz = 8000;
constexpr int kSampleRate16kHz = 16000;
constexpr int kSampleRate32kHz = 32000;
constexpr int kSampleRate48kHz = 48000;
// Bandlimiter coefficients computed based on that only
// the first 40 bins of the spectrum for the downsampled
// signal are used.
// [B,A] = butter(2,(41/64*4000)/8000)
const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_16kHz = {
{0.1455f, 0.2911f, 0.1455f},
{-0.6698f, 0.2520f}};
// [B,A] = butter(2,(41/64*4000)/16000)
const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_32kHz = {
{0.0462f, 0.0924f, 0.0462f},
{-1.3066f, 0.4915f}};
// [B,A] = butter(2,(41/64*4000)/24000)
const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_48kHz = {
{0.0226f, 0.0452f, 0.0226f},
{-1.5320f, 0.6224f}};
} // namespace
DownSampler::DownSampler(ApmDataDumper* data_dumper)
: data_dumper_(data_dumper) {
Initialize(48000);
}
void DownSampler::Initialize(int sample_rate_hz) {
RTC_DCHECK(
sample_rate_hz == kSampleRate8kHz || sample_rate_hz == kSampleRate16kHz ||
sample_rate_hz == kSampleRate32kHz || sample_rate_hz == kSampleRate48kHz);
sample_rate_hz_ = sample_rate_hz;
down_sampling_factor_ = rtc::CheckedDivExact(sample_rate_hz_, 8000);
// Note that the down sampling filter is not used if the sample rate is 8
// kHz.
if (sample_rate_hz_ == kSampleRate16kHz) {
low_pass_filter_.Initialize(kLowPassFilterCoefficients_16kHz);
} else if (sample_rate_hz_ == kSampleRate32kHz) {
low_pass_filter_.Initialize(kLowPassFilterCoefficients_32kHz);
} else if (sample_rate_hz_ == kSampleRate48kHz) {
low_pass_filter_.Initialize(kLowPassFilterCoefficients_48kHz);
}
}
void DownSampler::DownSample(rtc::ArrayView<const float> in,
rtc::ArrayView<float> out) {
data_dumper_->DumpWav("lc_down_sampler_input", in, sample_rate_hz_, 1);
RTC_DCHECK_EQ(sample_rate_hz_ * kChunkSizeMs / 1000, in.size());
RTC_DCHECK_EQ(kSampleRate8kHz * kChunkSizeMs / 1000, out.size());
const size_t kMaxNumFrames = kSampleRate48kHz * kChunkSizeMs / 1000;
float x[kMaxNumFrames];
// Band-limit the signal to 4 kHz.
if (sample_rate_hz_ != kSampleRate8kHz) {
low_pass_filter_.Process(in, rtc::ArrayView<float>(x, in.size()));
// Downsample the signal.
size_t k = 0;
for (size_t j = 0; j < out.size(); ++j) {
RTC_DCHECK_GT(kMaxNumFrames, k);
out[j] = x[k];
k += down_sampling_factor_;
}
} else {
std::copy(in.data(), in.data() + in.size(), out.data());
}
data_dumper_->DumpWav("lc_down_sampler_output", out, kSampleRate8kHz, 1);
}
} // namespace webrtc

View File

@ -1,42 +0,0 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_
#include "api/array_view.h"
#include "modules/audio_processing/agc2/biquad_filter.h"
namespace webrtc {
class ApmDataDumper;
class DownSampler {
public:
explicit DownSampler(ApmDataDumper* data_dumper);
DownSampler() = delete;
DownSampler(const DownSampler&) = delete;
DownSampler& operator=(const DownSampler&) = delete;
void Initialize(int sample_rate_hz);
void DownSample(rtc::ArrayView<const float> in, rtc::ArrayView<float> out);
private:
ApmDataDumper* data_dumper_;
int sample_rate_hz_;
int down_sampling_factor_;
BiQuadFilter low_pass_filter_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_

View File

@ -20,12 +20,21 @@
namespace webrtc {
namespace {
constexpr float kInitialFilterStateLevel = 0.f;
constexpr float kInitialFilterStateLevel = 0.0f;
// Instant attack.
constexpr float kAttackFilterConstant = 0.0f;
// Limiter decay constant.
// Computed as `10 ** (-1/20 * subframe_duration / kDecayMs)` where:
// - `subframe_duration` is `kFrameDurationMs / kSubFramesInFrame`;
// - `kDecayMs` is defined in agc2_testing_common.h.
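// For example, kFrameDurationMs = 10 and kSubFramesInFrame = 20 (agc2_common.h)
// give a 0.5 ms sub-frame, and with kDecayMs = 20 this yields
// 10 ** (-1/20 * 0.5 / 20) = 10 ** -0.00125 ~= 0.9971259.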
constexpr float kDecayFilterConstant = 0.9971259f;
} // namespace
FixedDigitalLevelEstimator::FixedDigitalLevelEstimator(
size_t sample_rate_hz,
int sample_rate_hz,
ApmDataDumper* apm_data_dumper)
: apm_data_dumper_(apm_data_dumper),
filter_state_level_(kInitialFilterStateLevel) {
@ -49,11 +58,11 @@ std::array<float, kSubFramesInFrame> FixedDigitalLevelEstimator::ComputeLevel(
// Compute max envelope without smoothing.
std::array<float, kSubFramesInFrame> envelope{};
for (size_t channel_idx = 0; channel_idx < float_frame.num_channels();
for (int channel_idx = 0; channel_idx < float_frame.num_channels();
++channel_idx) {
const auto channel = float_frame.channel(channel_idx);
for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
for (size_t sample_in_sub_frame = 0;
for (int sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
for (int sample_in_sub_frame = 0;
sample_in_sub_frame < samples_in_sub_frame_; ++sample_in_sub_frame) {
envelope[sub_frame] =
std::max(envelope[sub_frame],
@ -66,14 +75,14 @@ std::array<float, kSubFramesInFrame> FixedDigitalLevelEstimator::ComputeLevel(
// Make sure envelope increases happen one step earlier so that the
// corresponding *gain decrease* doesn't miss a sudden signal
// increase due to interpolation.
for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame - 1; ++sub_frame) {
for (int sub_frame = 0; sub_frame < kSubFramesInFrame - 1; ++sub_frame) {
if (envelope[sub_frame] < envelope[sub_frame + 1]) {
envelope[sub_frame] = envelope[sub_frame + 1];
}
}
// Add attack / decay smoothing.
for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
for (int sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
const float envelope_value = envelope[sub_frame];
if (envelope_value > filter_state_level_) {
envelope[sub_frame] = envelope_value * (1 - kAttackFilterConstant) +
@ -97,9 +106,9 @@ std::array<float, kSubFramesInFrame> FixedDigitalLevelEstimator::ComputeLevel(
return envelope;
}
void FixedDigitalLevelEstimator::SetSampleRate(size_t sample_rate_hz) {
samples_in_frame_ = rtc::CheckedDivExact(sample_rate_hz * kFrameDurationMs,
static_cast<size_t>(1000));
void FixedDigitalLevelEstimator::SetSampleRate(int sample_rate_hz) {
samples_in_frame_ =
rtc::CheckedDivExact(sample_rate_hz * kFrameDurationMs, 1000);
samples_in_sub_frame_ =
rtc::CheckedDivExact(samples_in_frame_, kSubFramesInFrame);
CheckParameterCombination();

View File

@ -16,7 +16,6 @@
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/constructor_magic.h"
namespace webrtc {
@ -31,9 +30,13 @@ class FixedDigitalLevelEstimator {
// kSubFramesInSample. For kFrameDurationMs=10 and
// kSubFramesInSample=20, this means that sample_rate_hz has to be
// divisible by 2000.
FixedDigitalLevelEstimator(size_t sample_rate_hz,
FixedDigitalLevelEstimator(int sample_rate_hz,
ApmDataDumper* apm_data_dumper);
FixedDigitalLevelEstimator(const FixedDigitalLevelEstimator&) = delete;
FixedDigitalLevelEstimator& operator=(const FixedDigitalLevelEstimator&) =
delete;
// The input is assumed to be in FloatS16 format. Scaled input will
// produce similarly scaled output. A frame of with kFrameDurationMs
// ms of audio produces a level estimates in the same scale. The
@ -43,7 +46,7 @@ class FixedDigitalLevelEstimator {
// Rate may be changed at any time (but not concurrently) from the
// value passed to the constructor. The class is not thread safe.
void SetSampleRate(size_t sample_rate_hz);
void SetSampleRate(int sample_rate_hz);
// Resets the level estimator internal state.
void Reset();
@ -55,10 +58,8 @@ class FixedDigitalLevelEstimator {
ApmDataDumper* const apm_data_dumper_ = nullptr;
float filter_state_level_;
size_t samples_in_frame_;
size_t samples_in_sub_frame_;
RTC_DISALLOW_COPY_AND_ASSIGN(FixedDigitalLevelEstimator);
int samples_in_frame_;
int samples_in_sub_frame_;
};
} // namespace webrtc

View File

@ -1,101 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/fixed_gain_controller.h"
#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
// Returns true when the gain factor is so close to 1 that it would
// not affect int16 samples.
bool CloseToOne(float gain_factor) {
return 1.f - 1.f / kMaxFloatS16Value <= gain_factor &&
gain_factor <= 1.f + 1.f / kMaxFloatS16Value;
}
} // namespace
FixedGainController::FixedGainController(ApmDataDumper* apm_data_dumper)
: FixedGainController(apm_data_dumper, "Agc2") {}
FixedGainController::FixedGainController(ApmDataDumper* apm_data_dumper,
std::string histogram_name_prefix)
: apm_data_dumper_(apm_data_dumper),
limiter_(48000, apm_data_dumper_, histogram_name_prefix) {
// Do update histograms.xml when adding name prefixes.
RTC_DCHECK(histogram_name_prefix == "" || histogram_name_prefix == "Test" ||
histogram_name_prefix == "AudioMixer" ||
histogram_name_prefix == "Agc2");
}
void FixedGainController::SetGain(float gain_to_apply_db) {
// Changes in gain_to_apply_ cause discontinuities. We assume
// gain_to_apply_ is set in the beginning of the call. If it is
// frequently changed, we should add interpolation between the
// values.
RTC_DCHECK_LE(-50.f, gain_to_apply_db);
RTC_DCHECK_LE(gain_to_apply_db, 50.f);
const float previous_applied_gained = gain_to_apply_;
gain_to_apply_ = DbToRatio(gain_to_apply_db);
RTC_DCHECK_LT(0.f, gain_to_apply_);
RTC_DLOG(LS_INFO) << "Gain to apply: " << gain_to_apply_db << " db.";
// Reset the gain curve applier to quickly react on abrupt level changes
// caused by large changes of the applied gain.
if (previous_applied_gained != gain_to_apply_) {
limiter_.Reset();
}
}
void FixedGainController::SetSampleRate(size_t sample_rate_hz) {
limiter_.SetSampleRate(sample_rate_hz);
}
void FixedGainController::Process(AudioFrameView<float> signal) {
// Apply fixed digital gain. One of the
// planned usages of the FGC is to only use the limiter. In that
// case, the gain would be 1.0. Not doing the multiplications speeds
// it up considerably. Hence the check.
if (!CloseToOne(gain_to_apply_)) {
for (size_t k = 0; k < signal.num_channels(); ++k) {
rtc::ArrayView<float> channel_view = signal.channel(k);
for (auto& sample : channel_view) {
sample *= gain_to_apply_;
}
}
}
// Use the limiter.
limiter_.Process(signal);
// Dump data for debug.
const auto channel_view = signal.channel(0);
apm_data_dumper_->DumpRaw("agc2_fixed_digital_gain_curve_applier",
channel_view.size(), channel_view.data());
// Hard-clipping.
for (size_t k = 0; k < signal.num_channels(); ++k) {
rtc::ArrayView<float> channel_view = signal.channel(k);
for (auto& sample : channel_view) {
sample = rtc::SafeClamp(sample, kMinFloatS16Value, kMaxFloatS16Value);
}
}
}
float FixedGainController::LastAudioLevel() const {
return limiter_.LastAudioLevel();
}
} // namespace webrtc

View File

@ -25,7 +25,7 @@ bool GainCloseToOne(float gain_factor) {
}
void ClipSignal(AudioFrameView<float> signal) {
for (size_t k = 0; k < signal.num_channels(); ++k) {
for (int k = 0; k < signal.num_channels(); ++k) {
rtc::ArrayView<float> channel_view = signal.channel(k);
for (auto& sample : channel_view) {
sample = rtc::SafeClamp(sample, kMinFloatS16Value, kMaxFloatS16Value);
@ -45,7 +45,7 @@ void ApplyGainWithRamping(float last_gain_linear,
// Gain is constant and different from 1.
if (last_gain_linear == gain_at_end_of_frame_linear) {
for (size_t k = 0; k < float_frame.num_channels(); ++k) {
for (int k = 0; k < float_frame.num_channels(); ++k) {
rtc::ArrayView<float> channel_view = float_frame.channel(k);
for (auto& sample : channel_view) {
sample *= gain_at_end_of_frame_linear;
@ -58,8 +58,8 @@ void ApplyGainWithRamping(float last_gain_linear,
const float increment = (gain_at_end_of_frame_linear - last_gain_linear) *
inverse_samples_per_channel;
float gain = last_gain_linear;
for (size_t i = 0; i < float_frame.samples_per_channel(); ++i) {
for (size_t ch = 0; ch < float_frame.num_channels(); ++ch) {
for (int i = 0; i < float_frame.samples_per_channel(); ++i) {
for (int ch = 0; ch < float_frame.num_channels(); ++ch) {
float_frame.channel(ch)[i] *= gain;
}
gain += increment;
@ -88,12 +88,13 @@ void GainApplier::ApplyGain(AudioFrameView<float> signal) {
}
}
// TODO(bugs.webrtc.org/7494): Remove once switched to gains in dB.
void GainApplier::SetGainFactor(float gain_factor) {
RTC_DCHECK_GT(gain_factor, 0.f);
current_gain_factor_ = gain_factor;
}
void GainApplier::Initialize(size_t samples_per_channel) {
void GainApplier::Initialize(int samples_per_channel) {
RTC_DCHECK_GT(samples_per_channel, 0);
samples_per_channel_ = static_cast<int>(samples_per_channel);
inverse_samples_per_channel_ = 1.f / samples_per_channel_;

View File

@ -25,7 +25,7 @@ class GainApplier {
float GetGainFactor() const { return current_gain_factor_; }
private:
void Initialize(size_t samples_per_channel);
void Initialize(int samples_per_channel);
// Whether to clip samples after gain is applied. If 'true', result
// will fit in FloatS16 range.

View File

@ -0,0 +1,46 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_GAIN_MAP_INTERNAL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_GAIN_MAP_INTERNAL_H_
namespace webrtc {
static constexpr int kGainMapSize = 256;
// Maps input volumes, which are values in the [0, 255] range, to gains in dB.
// The values below are generated with numpy as follows:
// SI = 2 # Initial slope.
// SF = 0.25 # Final slope.
// D = 8/256 # Quantization factor.
// x = np.linspace(0, 255, 256) # Input volumes.
// y = (SF * x + (SI - SF) * (1 - np.exp(-D*x)) / D - 56).round()
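// Sanity check at x = 255: 0.25*255 + 1.75*(1 - exp(-255*8/256))/(8/256) - 56
// ~= 63.75 + 55.98 - 56 ~= 63.73, which rounds to 64, the last entry below.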
static const int kGainMap[kGainMapSize] = {
-56, -54, -52, -50, -48, -47, -45, -43, -42, -40, -38, -37, -35, -34, -33,
-31, -30, -29, -27, -26, -25, -24, -23, -22, -20, -19, -18, -17, -16, -15,
-14, -14, -13, -12, -11, -10, -9, -8, -8, -7, -6, -5, -5, -4, -3,
-2, -2, -1, 0, 0, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6,
6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13,
13, 14, 14, 15, 15, 15, 16, 16, 17, 17, 17, 18, 18, 18, 19,
19, 19, 20, 20, 21, 21, 21, 22, 22, 22, 23, 23, 23, 24, 24,
24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 28,
29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 32, 32, 32, 32, 33,
33, 33, 33, 34, 34, 34, 35, 35, 35, 35, 36, 36, 36, 36, 37,
37, 37, 38, 38, 38, 38, 39, 39, 39, 39, 40, 40, 40, 40, 41,
41, 41, 41, 42, 42, 42, 42, 43, 43, 43, 44, 44, 44, 44, 45,
45, 45, 45, 46, 46, 46, 46, 47, 47, 47, 47, 48, 48, 48, 48,
49, 49, 49, 49, 50, 50, 50, 50, 51, 51, 51, 51, 52, 52, 52,
52, 53, 53, 53, 53, 54, 54, 54, 54, 55, 55, 55, 55, 56, 56,
56, 56, 57, 57, 57, 57, 58, 58, 58, 58, 59, 59, 59, 59, 60,
60, 60, 60, 61, 61, 61, 61, 62, 62, 62, 62, 63, 63, 63, 63,
64};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_GAIN_MAP_INTERNAL_H_

View File

@ -0,0 +1,580 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/input_volume_controller.h"
#include <algorithm>
#include <cmath>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/gain_map_internal.h"
#include "modules/audio_processing/agc2/input_volume_stats_reporter.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
#include "system_wrappers/include/field_trial.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
namespace {
// Amount of error we tolerate in the microphone input volume (presumably due to
// OS quantization) before we assume the user has manually adjusted the volume.
constexpr int kVolumeQuantizationSlack = 25;
constexpr int kMaxInputVolume = 255;
static_assert(kGainMapSize > kMaxInputVolume, "gain map too small");
// Maximum absolute RMS error.
constexpr int KMaxAbsRmsErrorDbfs = 15;
static_assert(KMaxAbsRmsErrorDbfs > 0, "");
using Agc1ClippingPredictorConfig = AudioProcessing::Config::GainController1::
AnalogGainController::ClippingPredictor;
// TODO(webrtc:7494): Hardcode clipping predictor parameters and remove this
// function after no longer needed in the ctor.
Agc1ClippingPredictorConfig CreateClippingPredictorConfig(bool enabled) {
Agc1ClippingPredictorConfig config;
config.enabled = enabled;
return config;
}
// Returns an input volume in the [`min_input_volume`, `kMaxInputVolume`] range
// that reduces `gain_error_db`, which is a gain error estimated when
// `input_volume` was applied, according to a fixed gain map.
int ComputeVolumeUpdate(int gain_error_db,
int input_volume,
int min_input_volume) {
RTC_DCHECK_GE(input_volume, 0);
RTC_DCHECK_LE(input_volume, kMaxInputVolume);
if (gain_error_db == 0) {
return input_volume;
}
int new_volume = input_volume;
if (gain_error_db > 0) {
while (kGainMap[new_volume] - kGainMap[input_volume] < gain_error_db &&
new_volume < kMaxInputVolume) {
++new_volume;
}
} else {
while (kGainMap[new_volume] - kGainMap[input_volume] > gain_error_db &&
new_volume > min_input_volume) {
--new_volume;
}
}
return new_volume;
}
// Returns the proportion of samples in the buffer which are at full-scale
// (and presumably clipped).
float ComputeClippedRatio(const float* const* audio,
size_t num_channels,
size_t samples_per_channel) {
RTC_DCHECK_GT(samples_per_channel, 0);
int num_clipped = 0;
for (size_t ch = 0; ch < num_channels; ++ch) {
int num_clipped_in_ch = 0;
for (size_t i = 0; i < samples_per_channel; ++i) {
RTC_DCHECK(audio[ch]);
if (audio[ch][i] >= 32767.0f || audio[ch][i] <= -32768.0f) {
++num_clipped_in_ch;
}
}
num_clipped = std::max(num_clipped, num_clipped_in_ch);
}
return static_cast<float>(num_clipped) / (samples_per_channel);
}
void LogClippingMetrics(int clipping_rate) {
RTC_LOG(LS_INFO) << "[AGC2] Input clipping rate: " << clipping_rate << "%";
RTC_HISTOGRAM_COUNTS_LINEAR(/*name=*/"WebRTC.Audio.Agc.InputClippingRate",
/*sample=*/clipping_rate, /*min=*/0, /*max=*/100,
/*bucket_count=*/50);
}
// Compares `speech_level_dbfs` to the [`target_range_min_dbfs`,
// `target_range_max_dbfs`] range and returns the error to be compensated via
// input volume adjustment. Returns a positive value when the level is below
// the range, a negative value when the level is above the range, zero
// otherwise.
int GetSpeechLevelRmsErrorDb(float speech_level_dbfs,
int target_range_min_dbfs,
int target_range_max_dbfs) {
constexpr float kMinSpeechLevelDbfs = -90.0f;
constexpr float kMaxSpeechLevelDbfs = 30.0f;
RTC_DCHECK_GE(speech_level_dbfs, kMinSpeechLevelDbfs);
RTC_DCHECK_LE(speech_level_dbfs, kMaxSpeechLevelDbfs);
speech_level_dbfs = rtc::SafeClamp<float>(
speech_level_dbfs, kMinSpeechLevelDbfs, kMaxSpeechLevelDbfs);
int rms_error_db = 0;
if (speech_level_dbfs > target_range_max_dbfs) {
rms_error_db = std::round(target_range_max_dbfs - speech_level_dbfs);
} else if (speech_level_dbfs < target_range_min_dbfs) {
rms_error_db = std::round(target_range_min_dbfs - speech_level_dbfs);
}
return rms_error_db;
}
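// Example: with a [-30, -20] dBFS target range, a speech level of -36.5 dBFS
// yields round(-30 + 36.5) = +7 (raise the volume), a level of -12.0 dBFS
// yields round(-20 + 12.0) = -8 (lower it), and levels inside the range
// yield 0.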
} // namespace
MonoInputVolumeController::MonoInputVolumeController(
int min_input_volume_after_clipping,
int min_input_volume,
int update_input_volume_wait_frames,
float speech_probability_threshold,
float speech_ratio_threshold)
: min_input_volume_(min_input_volume),
min_input_volume_after_clipping_(min_input_volume_after_clipping),
max_input_volume_(kMaxInputVolume),
update_input_volume_wait_frames_(
std::max(update_input_volume_wait_frames, 1)),
speech_probability_threshold_(speech_probability_threshold),
speech_ratio_threshold_(speech_ratio_threshold) {
RTC_DCHECK_GE(min_input_volume_, 0);
RTC_DCHECK_LE(min_input_volume_, 255);
RTC_DCHECK_GE(min_input_volume_after_clipping_, 0);
RTC_DCHECK_LE(min_input_volume_after_clipping_, 255);
RTC_DCHECK_GE(max_input_volume_, 0);
RTC_DCHECK_LE(max_input_volume_, 255);
RTC_DCHECK_GE(update_input_volume_wait_frames_, 0);
RTC_DCHECK_GE(speech_probability_threshold_, 0.0f);
RTC_DCHECK_LE(speech_probability_threshold_, 1.0f);
RTC_DCHECK_GE(speech_ratio_threshold_, 0.0f);
RTC_DCHECK_LE(speech_ratio_threshold_, 1.0f);
}
MonoInputVolumeController::~MonoInputVolumeController() = default;
void MonoInputVolumeController::Initialize() {
max_input_volume_ = kMaxInputVolume;
capture_output_used_ = true;
check_volume_on_next_process_ = true;
frames_since_update_input_volume_ = 0;
speech_frames_since_update_input_volume_ = 0;
is_first_frame_ = true;
}
// A speech segment is considered active if at least
// `update_input_volume_wait_frames_` new frames have been processed since the
// previous update and the ratio of non-silence frames (i.e., frames with a
// `speech_probability` higher than `speech_probability_threshold_`) is at least
// `speech_ratio_threshold_`.
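// For example, with hypothetical values `update_input_volume_wait_frames_` =
// 100 and `speech_ratio_threshold_` = 0.6, the volume is reconsidered at most
// once every 100 frames, and only if at least 60 of them counted as speech.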
void MonoInputVolumeController::Process(absl::optional<int> rms_error_db,
float speech_probability) {
if (check_volume_on_next_process_) {
check_volume_on_next_process_ = false;
// We have to wait until the first process call to check the volume,
// because Chromium doesn't guarantee it to be valid any earlier.
CheckVolumeAndReset();
}
// Count frames with a high speech probability as speech.
if (speech_probability >= speech_probability_threshold_) {
++speech_frames_since_update_input_volume_;
}
// Reset the counters and maybe update the input volume.
if (++frames_since_update_input_volume_ >= update_input_volume_wait_frames_) {
const float speech_ratio =
static_cast<float>(speech_frames_since_update_input_volume_) /
static_cast<float>(update_input_volume_wait_frames_);
// Always reset the counters regardless of whether the volume changes or
// not.
frames_since_update_input_volume_ = 0;
speech_frames_since_update_input_volume_ = 0;
// Update the input volume if allowed.
if (!is_first_frame_ && speech_ratio >= speech_ratio_threshold_ &&
rms_error_db.has_value()) {
UpdateInputVolume(*rms_error_db);
}
}
is_first_frame_ = false;
}
void MonoInputVolumeController::HandleClipping(int clipped_level_step) {
RTC_DCHECK_GT(clipped_level_step, 0);
// Always decrease the maximum input volume, even if the current input volume
// is below threshold.
SetMaxLevel(std::max(min_input_volume_after_clipping_,
max_input_volume_ - clipped_level_step));
if (log_to_histograms_) {
RTC_HISTOGRAM_BOOLEAN("WebRTC.Audio.AgcClippingAdjustmentAllowed",
last_recommended_input_volume_ - clipped_level_step >=
min_input_volume_after_clipping_);
}
if (last_recommended_input_volume_ > min_input_volume_after_clipping_) {
// Don't try to adjust the input volume if we're already below the limit. As
// a consequence, if the user has brought the input volume above the limit,
// we will still not react until the postproc updates the input volume.
SetInputVolume(
std::max(min_input_volume_after_clipping_,
last_recommended_input_volume_ - clipped_level_step));
frames_since_update_input_volume_ = 0;
speech_frames_since_update_input_volume_ = 0;
is_first_frame_ = false;
}
}
void MonoInputVolumeController::SetInputVolume(int new_volume) {
int applied_input_volume = recommended_input_volume_;
if (applied_input_volume == 0) {
RTC_DLOG(LS_INFO)
<< "[AGC2] The applied input volume is zero, taking no action.";
return;
}
if (applied_input_volume < 0 || applied_input_volume > kMaxInputVolume) {
RTC_LOG(LS_ERROR) << "[AGC2] Invalid value for the applied input volume: "
<< applied_input_volume;
return;
}
// Detect manual input volume adjustments by checking if the
// `applied_input_volume` is outside of the `[last_recommended_input_volume_ -
// kVolumeQuantizationSlack, last_recommended_input_volume_ +
// kVolumeQuantizationSlack]` range.
if (applied_input_volume >
last_recommended_input_volume_ + kVolumeQuantizationSlack ||
applied_input_volume <
last_recommended_input_volume_ - kVolumeQuantizationSlack) {
RTC_DLOG(LS_INFO)
<< "[AGC2] The input volume was manually adjusted. Updating "
"stored input volume from "
<< last_recommended_input_volume_ << " to " << applied_input_volume;
last_recommended_input_volume_ = applied_input_volume;
// Always allow the user to increase the volume.
if (last_recommended_input_volume_ > max_input_volume_) {
SetMaxLevel(last_recommended_input_volume_);
}
// Take no action in this case, since we can't be sure when the volume
// was manually adjusted.
frames_since_update_input_volume_ = 0;
speech_frames_since_update_input_volume_ = 0;
is_first_frame_ = false;
return;
}
new_volume = std::min(new_volume, max_input_volume_);
if (new_volume == last_recommended_input_volume_) {
return;
}
recommended_input_volume_ = new_volume;
RTC_DLOG(LS_INFO) << "[AGC2] Applied input volume: " << applied_input_volume
<< " | last recommended input volume: "
<< last_recommended_input_volume_
<< " | newly recommended input volume: " << new_volume;
last_recommended_input_volume_ = new_volume;
}
void MonoInputVolumeController::SetMaxLevel(int input_volume) {
RTC_DCHECK_GE(input_volume, min_input_volume_after_clipping_);
max_input_volume_ = input_volume;
RTC_DLOG(LS_INFO) << "[AGC2] Maximum input volume updated: "
<< max_input_volume_;
}
void MonoInputVolumeController::HandleCaptureOutputUsedChange(
bool capture_output_used) {
if (capture_output_used_ == capture_output_used) {
return;
}
capture_output_used_ = capture_output_used;
if (capture_output_used) {
// When we start using the output, we should reset things to be safe.
check_volume_on_next_process_ = true;
}
}
int MonoInputVolumeController::CheckVolumeAndReset() {
int input_volume = recommended_input_volume_;
// Reasons for taking action at startup:
// 1) A person starting a call is expected to be heard.
// 2) Independent of interpretation of `input_volume` == 0 we should raise it
// so the AGC can do its job properly.
if (input_volume == 0 && !startup_) {
RTC_DLOG(LS_INFO)
<< "[AGC2] The applied input volume is zero, taking no action.";
return 0;
}
if (input_volume < 0 || input_volume > kMaxInputVolume) {
RTC_LOG(LS_ERROR) << "[AGC2] Invalid value for the applied input volume: "
<< input_volume;
return -1;
}
RTC_DLOG(LS_INFO) << "[AGC2] Initial input volume: " << input_volume;
if (input_volume < min_input_volume_) {
input_volume = min_input_volume_;
RTC_DLOG(LS_INFO)
<< "[AGC2] The initial input volume is too low, raising to "
<< input_volume;
recommended_input_volume_ = input_volume;
}
last_recommended_input_volume_ = input_volume;
startup_ = false;
frames_since_update_input_volume_ = 0;
speech_frames_since_update_input_volume_ = 0;
is_first_frame_ = true;
return 0;
}
void MonoInputVolumeController::UpdateInputVolume(int rms_error_db) {
RTC_DLOG(LS_INFO) << "[AGC2] RMS error: " << rms_error_db << " dB";
// Prevent too large microphone input volume changes by clamping the RMS
// error.
rms_error_db =
rtc::SafeClamp(rms_error_db, -KMaxAbsRmsErrorDbfs, KMaxAbsRmsErrorDbfs);
if (rms_error_db == 0) {
return;
}
SetInputVolume(ComputeVolumeUpdate(
rms_error_db, last_recommended_input_volume_, min_input_volume_));
}
InputVolumeController::InputVolumeController(int num_capture_channels,
const Config& config)
: num_capture_channels_(num_capture_channels),
min_input_volume_(config.min_input_volume),
capture_output_used_(true),
clipped_level_step_(config.clipped_level_step),
clipped_ratio_threshold_(config.clipped_ratio_threshold),
clipped_wait_frames_(config.clipped_wait_frames),
clipping_predictor_(CreateClippingPredictor(
num_capture_channels,
CreateClippingPredictorConfig(config.enable_clipping_predictor))),
use_clipping_predictor_step_(
!!clipping_predictor_ &&
CreateClippingPredictorConfig(config.enable_clipping_predictor)
.use_predicted_step),
frames_since_clipped_(config.clipped_wait_frames),
clipping_rate_log_counter_(0),
clipping_rate_log_(0.0f),
target_range_max_dbfs_(config.target_range_max_dbfs),
target_range_min_dbfs_(config.target_range_min_dbfs),
channel_controllers_(num_capture_channels) {
RTC_LOG(LS_INFO)
<< "[AGC2] Input volume controller enabled. Minimum input volume: "
<< min_input_volume_;
for (auto& controller : channel_controllers_) {
controller = std::make_unique<MonoInputVolumeController>(
config.clipped_level_min, min_input_volume_,
config.update_input_volume_wait_frames,
config.speech_probability_threshold, config.speech_ratio_threshold);
}
RTC_DCHECK(!channel_controllers_.empty());
RTC_DCHECK_GT(clipped_level_step_, 0);
RTC_DCHECK_LE(clipped_level_step_, 255);
RTC_DCHECK_GT(clipped_ratio_threshold_, 0.0f);
RTC_DCHECK_LT(clipped_ratio_threshold_, 1.0f);
RTC_DCHECK_GT(clipped_wait_frames_, 0);
channel_controllers_[0]->ActivateLogging();
}
InputVolumeController::~InputVolumeController() {}
void InputVolumeController::Initialize() {
for (auto& controller : channel_controllers_) {
controller->Initialize();
}
capture_output_used_ = true;
AggregateChannelLevels();
clipping_rate_log_ = 0.0f;
clipping_rate_log_counter_ = 0;
applied_input_volume_ = absl::nullopt;
}
void InputVolumeController::AnalyzeInputAudio(int applied_input_volume,
const AudioBuffer& audio_buffer) {
RTC_DCHECK_GE(applied_input_volume, 0);
RTC_DCHECK_LE(applied_input_volume, 255);
SetAppliedInputVolume(applied_input_volume);
RTC_DCHECK_EQ(audio_buffer.num_channels(), channel_controllers_.size());
const float* const* audio = audio_buffer.channels_const();
size_t samples_per_channel = audio_buffer.num_frames();
RTC_DCHECK(audio);
AggregateChannelLevels();
if (!capture_output_used_) {
return;
}
if (!!clipping_predictor_) {
AudioFrameView<const float> frame = AudioFrameView<const float>(
audio, num_capture_channels_, static_cast<int>(samples_per_channel));
clipping_predictor_->Analyze(frame);
}
// Check for clipped samples. We do this in the preprocessing phase in order
// to catch clipped echo as well.
//
// If we find a sufficiently clipped frame, drop the current microphone
// input volume and enforce a new maximum input volume, lowered by the same
// amount from the current maximum. This harsh treatment is an effort to avoid
// repeated clipped echo events.
float clipped_ratio =
ComputeClippedRatio(audio, num_capture_channels_, samples_per_channel);
clipping_rate_log_ = std::max(clipped_ratio, clipping_rate_log_);
clipping_rate_log_counter_++;
constexpr int kNumFramesIn30Seconds = 3000;
if (clipping_rate_log_counter_ == kNumFramesIn30Seconds) {
LogClippingMetrics(std::round(100.0f * clipping_rate_log_));
clipping_rate_log_ = 0.0f;
clipping_rate_log_counter_ = 0;
}
if (frames_since_clipped_ < clipped_wait_frames_) {
++frames_since_clipped_;
return;
}
const bool clipping_detected = clipped_ratio > clipped_ratio_threshold_;
bool clipping_predicted = false;
int predicted_step = 0;
if (!!clipping_predictor_) {
for (int channel = 0; channel < num_capture_channels_; ++channel) {
const auto step = clipping_predictor_->EstimateClippedLevelStep(
channel, recommended_input_volume_, clipped_level_step_,
channel_controllers_[channel]->min_input_volume_after_clipping(),
kMaxInputVolume);
if (step.has_value()) {
predicted_step = std::max(predicted_step, step.value());
clipping_predicted = true;
}
}
}
if (clipping_detected) {
RTC_DLOG(LS_INFO) << "[AGC2] Clipping detected (ratio: " << clipped_ratio
<< ")";
}
int step = clipped_level_step_;
if (clipping_predicted) {
predicted_step = std::max(predicted_step, clipped_level_step_);
RTC_DLOG(LS_INFO) << "[AGC2] Clipping predicted (volume down step: "
<< predicted_step << ")";
if (use_clipping_predictor_step_) {
step = predicted_step;
}
}
if (clipping_detected ||
(clipping_predicted && use_clipping_predictor_step_)) {
for (auto& state_ch : channel_controllers_) {
state_ch->HandleClipping(step);
}
frames_since_clipped_ = 0;
if (!!clipping_predictor_) {
clipping_predictor_->Reset();
}
}
AggregateChannelLevels();
}
absl::optional<int> InputVolumeController::RecommendInputVolume(
float speech_probability,
absl::optional<float> speech_level_dbfs) {
// Only process if applied input volume is set.
if (!applied_input_volume_.has_value()) {
RTC_LOG(LS_ERROR) << "[AGC2] Applied input volume not set.";
return absl::nullopt;
}
AggregateChannelLevels();
const int volume_after_clipping_handling = recommended_input_volume_;
if (!capture_output_used_) {
return applied_input_volume_;
}
absl::optional<int> rms_error_db;
if (speech_level_dbfs.has_value()) {
// Compute the error for all frames (both speech and non-speech frames).
rms_error_db = GetSpeechLevelRmsErrorDb(
*speech_level_dbfs, target_range_min_dbfs_, target_range_max_dbfs_);
}
for (auto& controller : channel_controllers_) {
controller->Process(rms_error_db, speech_probability);
}
AggregateChannelLevels();
if (volume_after_clipping_handling != recommended_input_volume_) {
// The recommended input volume was adjusted in order to match the target
// level.
UpdateHistogramOnRecommendedInputVolumeChangeToMatchTarget(
recommended_input_volume_);
}
applied_input_volume_ = absl::nullopt;
return recommended_input_volume();
}
void InputVolumeController::HandleCaptureOutputUsedChange(
bool capture_output_used) {
for (auto& controller : channel_controllers_) {
controller->HandleCaptureOutputUsedChange(capture_output_used);
}
capture_output_used_ = capture_output_used;
}
void InputVolumeController::SetAppliedInputVolume(int input_volume) {
applied_input_volume_ = input_volume;
for (auto& controller : channel_controllers_) {
controller->set_stream_analog_level(input_volume);
}
AggregateChannelLevels();
}
void InputVolumeController::AggregateChannelLevels() {
int new_recommended_input_volume =
channel_controllers_[0]->recommended_analog_level();
channel_controlling_gain_ = 0;
for (size_t ch = 1; ch < channel_controllers_.size(); ++ch) {
int input_volume = channel_controllers_[ch]->recommended_analog_level();
if (input_volume < new_recommended_input_volume) {
new_recommended_input_volume = input_volume;
channel_controlling_gain_ = static_cast<int>(ch);
}
}
// Enforce the minimum input volume when a recommendation is made.
if (applied_input_volume_.has_value() && *applied_input_volume_ > 0) {
new_recommended_input_volume =
std::max(new_recommended_input_volume, min_input_volume_);
}
recommended_input_volume_ = new_recommended_input_volume;
}
} // namespace webrtc
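The manual-adjustment detection in `SetInputVolume()` above reduces to a slack-range check around the last recommended volume. A minimal standalone sketch of that check, assuming a slack of 25; the real `kVolumeQuantizationSlack` is defined earlier in this file, and `IsManualAdjustment()` is an illustrative name, not a WebRTC function:
#include <cstdlib>
#include <iostream>
namespace {
// Assumed slack value; OS mixers may quantize the requested volume, so small
// deviations from the last recommendation are not treated as user actions.
constexpr int kVolumeQuantizationSlack = 25;
bool IsManualAdjustment(int applied_volume, int last_recommended_volume) {
  return std::abs(applied_volume - last_recommended_volume) >
         kVolumeQuantizationSlack;
}
}  // namespace
int main() {
  // Quantization 16 -> 20 stays within the slack: not a manual change.
  std::cout << IsManualAdjustment(20, 16) << "\n";  // 0
  // A slider drag from 50 to 200 is flagged as a manual change.
  std::cout << IsManualAdjustment(200, 50) << "\n";  // 1
}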

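The clipping branch in `AnalyzeInputAudio()` is driven by `ComputeClippedRatio()`, which is defined earlier in this file and not shown here. Purely as an assumption about its shape, a standalone sketch of a clipped-ratio measure that counts near-full-scale samples and takes the worst channel:
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
// Hypothetical clipped-ratio measure over float S16-range samples; the real
// ComputeClippedRatio() may differ.
float ClippedRatio(const std::vector<std::vector<float>>& channels) {
  float max_ratio = 0.0f;
  for (const auto& channel : channels) {
    int clipped = 0;
    for (float sample : channel) {
      if (std::fabs(sample) >= 32767.0f) ++clipped;
    }
    max_ratio =
        std::max(max_ratio, clipped / static_cast<float>(channel.size()));
  }
  return max_ratio;
}
int main() {
  // One channel, 2 of 4 samples at full scale: ratio 0.5, which exceeds the
  // default `clipped_ratio_threshold` of 0.1.
  std::vector<std::vector<float>> frame = {{32767.0f, 100.0f, -32767.0f, 0.0f}};
  std::cout << ClippedRatio(frame) << "\n";  // 0.5
}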
View File

@ -0,0 +1,282 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_CONTROLLER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_CONTROLLER_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/clipping_predictor.h"
#include "modules/audio_processing/audio_buffer.h"
#include "modules/audio_processing/include/audio_processing.h"
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
class MonoInputVolumeController;
// The input volume controller recommends what volume to use, handles volume
// changes, and detects and predicts clipping. In particular, it handles volume
// changes triggered by the user (e.g., the volume being set to zero by a HW
// mute button). This class is not thread-safe.
// TODO(bugs.webrtc.org/7494): Use applied/recommended input volume naming
// convention.
class InputVolumeController final {
public:
// Config for the constructor.
struct Config {
// Minimum input volume that can be recommended. Not enforced when the
// applied input volume is zero outside startup.
int min_input_volume = 20;
// Lowest input volume level that will be applied in response to clipping.
int clipped_level_min = 70;
// Amount by which the input volume is lowered with every clipping event.
// Limited to (0, 255].
int clipped_level_step = 15;
// Proportion of clipped samples required to declare a clipping event.
// Limited to (0.0f, 1.0f).
float clipped_ratio_threshold = 0.1f;
// Time in frames to wait after a clipping event before checking again.
// Limited to values higher than 0.
int clipped_wait_frames = 300;
// Enables clipping prediction functionality.
bool enable_clipping_predictor = false;
// Speech level target range (dBFS). If the speech level is in the range
// [`target_range_min_dbfs`, `target_range_max_dbfs`], no input volume
// adjustments are done based on the speech level. For speech levels below
// and above the range, the targets `target_range_min_dbfs` and
// `target_range_max_dbfs` are used, respectively.
int target_range_max_dbfs = -30;
int target_range_min_dbfs = -50;
// Number of wait frames between the recommended input volume updates.
int update_input_volume_wait_frames = 100;
// Speech probability threshold: speech probabilities below the threshold
// are considered silence. Limited to [0.0f, 1.0f].
float speech_probability_threshold = 0.7f;
// Minimum speech frame ratio for volume updates to be allowed. Limited to
// [0.0f, 1.0f].
float speech_ratio_threshold = 0.6f;
};
// Ctor. `num_capture_channels` specifies the number of channels for the audio
// passed to `AnalyzeInputAudio()` and `RecommendInputVolume()`.
InputVolumeController(int num_capture_channels, const Config& config);
~InputVolumeController();
InputVolumeController(const InputVolumeController&) = delete;
InputVolumeController& operator=(const InputVolumeController&) = delete;
// TODO(webrtc:7494): Integrate initialization into ctor and remove.
void Initialize();
// Analyzes `audio_buffer` before `RecommendInputVolume()` is called so that
// the analysis can be performed before digital processing operations take
// place (e.g., echo cancellation). The analysis consists of input clipping
// detection and prediction (if enabled).
void AnalyzeInputAudio(int applied_input_volume,
const AudioBuffer& audio_buffer);
// Adjusts the recommended input volume upwards/downwards based on the result
// of `AnalyzeInputAudio()` and on `speech_level_dbfs` (if specified). Must
// be called after `AnalyzeInputAudio()`. The value of `speech_probability`
// is expected to be in the range [0, 1] and `speech_level_dbfs` in the range
// [-90, 30] and both should be estimated after echo cancellation and noise
// suppression are applied. Returns a non-empty input volume recommendation if
// available. If `capture_output_used_` is false, returns the applied input
// volume.
absl::optional<int> RecommendInputVolume(
float speech_probability,
absl::optional<float> speech_level_dbfs);
// Stores whether the capture output will be used or not. Call when the
// capture stream output has been flagged to be used/not-used. If unused, the
// controller disregards all incoming audio.
void HandleCaptureOutputUsedChange(bool capture_output_used);
// Returns true if clipping prediction is enabled.
// TODO(bugs.webrtc.org/7494): Deprecate this method.
bool clipping_predictor_enabled() const { return !!clipping_predictor_; }
// Returns true if clipping prediction is used to adjust the input volume.
// TODO(bugs.webrtc.org/7494): Deprecate this method.
bool use_clipping_predictor_step() const {
return use_clipping_predictor_step_;
}
// Only use for testing: Use `RecommendInputVolume()` elsewhere.
// Returns the value of a member variable, needed for testing
// `AnalyzeInputAudio()`.
int recommended_input_volume() const { return recommended_input_volume_; }
// Only use for testing.
bool capture_output_used() const { return capture_output_used_; }
private:
friend class InputVolumeControllerTestHelper;
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest, MinInputVolumeDefault);
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest, MinInputVolumeDisabled);
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest,
MinInputVolumeOutOfRangeAbove);
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest,
MinInputVolumeOutOfRangeBelow);
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest, MinInputVolumeEnabled50);
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerParametrizedTest,
ClippingParametersVerified);
// Sets the applied input volume and resets the recommended input volume.
void SetAppliedInputVolume(int level);
void AggregateChannelLevels();
const int num_capture_channels_;
// Minimum input volume that can be recommended.
const int min_input_volume_;
// TODO(bugs.webrtc.org/7494): Once
// `AudioProcessingImpl::recommended_stream_analog_level()` becomes a trivial
// getter, leave uninitialized.
// Recommended input volume. After `SetAppliedInputVolume()` is called, it
// holds the observed input volume. Possibly updated by `AnalyzeInputAudio()`
// and `RecommendInputVolume()`; after these calls, holds the recommended input
// volume.
int recommended_input_volume_ = 0;
// Applied input volume. After `SetAppliedInputVolume()` is called it holds
// the current applied volume.
absl::optional<int> applied_input_volume_;
bool capture_output_used_;
// Clipping detection and prediction.
const int clipped_level_step_;
const float clipped_ratio_threshold_;
const int clipped_wait_frames_;
const std::unique_ptr<ClippingPredictor> clipping_predictor_;
const bool use_clipping_predictor_step_;
int frames_since_clipped_;
int clipping_rate_log_counter_;
float clipping_rate_log_;
// Target range minimum and maximum. If the speech level is in the range
// [`target_range_min_dbfs`, `target_range_max_dbfs`], no volume adjustments
// take place. Instead, the digital gain controller is assumed to adapt to
// compensate for the speech level RMS error.
const int target_range_max_dbfs_;
const int target_range_min_dbfs_;
// Channel controllers updating the gain upwards/downwards.
std::vector<std::unique_ptr<MonoInputVolumeController>> channel_controllers_;
int channel_controlling_gain_ = 0;
};
// TODO(bugs.webrtc.org/7494): Use applied/recommended input volume naming
// convention.
class MonoInputVolumeController {
public:
MonoInputVolumeController(int min_input_volume_after_clipping,
int min_input_volume,
int update_input_volume_wait_frames,
float speech_probability_threshold,
float speech_ratio_threshold);
~MonoInputVolumeController();
MonoInputVolumeController(const MonoInputVolumeController&) = delete;
MonoInputVolumeController& operator=(const MonoInputVolumeController&) =
delete;
void Initialize();
void HandleCaptureOutputUsedChange(bool capture_output_used);
// Sets the current input volume.
void set_stream_analog_level(int input_volume) {
recommended_input_volume_ = input_volume;
}
// Lowers the recommended input volume in response to clipping based on the
// suggested reduction `clipped_level_step`. Must be called after
// `set_stream_analog_level()`.
void HandleClipping(int clipped_level_step);
// TODO(bugs.webrtc.org/7494): Rename, audio not passed to the method anymore.
// Adjusts the recommended input volume upwards/downwards depending on the
// result of `HandleClipping()` and on `rms_error_dbfs`. Updates are only
// allowed for active speech segments and when `rms_error_dbfs` is not empty.
// Must be called after `HandleClipping()`.
void Process(absl::optional<int> rms_error_dbfs, float speech_probability);
// Returns the recommended input volume. Must be called after `Process()`.
int recommended_analog_level() const { return recommended_input_volume_; }
void ActivateLogging() { log_to_histograms_ = true; }
int min_input_volume_after_clipping() const {
return min_input_volume_after_clipping_;
}
// Only used for testing.
int min_input_volume() const { return min_input_volume_; }
private:
// Sets a new input volume, unless the user has manually adjusted the current
// one, in which case no action is taken.
void SetInputVolume(int new_volume);
// Sets the maximum input volume that the input volume controller is allowed
// to apply. The volume must be at least `min_input_volume_after_clipping_`.
void SetMaxLevel(int level);
int CheckVolumeAndReset();
// Updates the recommended input volume. If the volume slider needs to be
// moved, we check first if the user has adjusted it, in which case we take no
// action and cache the updated level.
void UpdateInputVolume(int rms_error_dbfs);
const int min_input_volume_;
const int min_input_volume_after_clipping_;
int max_input_volume_;
int last_recommended_input_volume_ = 0;
bool capture_output_used_ = true;
bool check_volume_on_next_process_ = true;
bool startup_ = true;
// TODO(bugs.webrtc.org/7494): Create a separate member for the applied
// input volume.
// Recommended input volume. After `set_stream_analog_level()` is
// called, it holds the observed applied input volume. Possibly updated by
// `HandleClipping()` and `Process()`; after these calls, holds the
// recommended input volume.
int recommended_input_volume_ = 0;
bool log_to_histograms_ = false;
// Counters for frames and speech frames since the last update in the
// recommended input volume.
const int update_input_volume_wait_frames_;
int frames_since_update_input_volume_ = 0;
int speech_frames_since_update_input_volume_ = 0;
bool is_first_frame_ = true;
// Speech probability threshold for a frame to be considered speech (instead
// of silence). Limited to [0.0f, 1.0f].
const float speech_probability_threshold_;
// Minimum ratio of speech frames. Limited to [0.0f, 1.0f].
const float speech_ratio_threshold_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_CONTROLLER_H_
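The target range documented in `Config` defines when speech-level-driven updates happen: inside [`target_range_min_dbfs`, `target_range_max_dbfs`] the error is zero, outside it the error is measured against the nearest bound. A sketch of that rule; the function name and the truncation to int are assumptions, not the WebRTC implementation:
#include <iostream>
// Positive error: the volume should be raised; negative: lowered; zero: no
// speech-level-driven adjustment. Truncation suffices for this sketch.
int SpeechLevelRmsErrorDb(float speech_level_dbfs,
                          int target_range_min_dbfs,
                          int target_range_max_dbfs) {
  if (speech_level_dbfs < target_range_min_dbfs) {
    return static_cast<int>(target_range_min_dbfs - speech_level_dbfs);
  }
  if (speech_level_dbfs > target_range_max_dbfs) {
    return static_cast<int>(target_range_max_dbfs - speech_level_dbfs);
  }
  return 0;
}
int main() {
  // With the default range [-50, -30] dBFS:
  std::cout << SpeechLevelRmsErrorDb(-62.0f, -50, -30) << "\n";  // 12: raise
  std::cout << SpeechLevelRmsErrorDb(-24.0f, -50, -30) << "\n";  // -6: lower
  std::cout << SpeechLevelRmsErrorDb(-40.0f, -50, -30) << "\n";  // 0: in range
}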

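The wait-frame and speech-ratio counters declared in `MonoInputVolumeController` admit a compact reading: an update may fire only once the wait period has elapsed, and only if enough of the observed frames were speech. A hedged standalone sketch of that gate; the real logic in `Process()` lives earlier in the .cc file and may differ in details:
#include <iostream>
// Toy gate: allow one update per `update_wait_frames` frames, and only when
// the speech ratio over that window clears the threshold.
struct UpdateGate {
  int update_wait_frames;
  float speech_ratio_threshold;
  int frames = 0;
  int speech_frames = 0;
  bool OnFrame(bool is_speech) {
    ++frames;
    if (is_speech) ++speech_frames;
    if (frames < update_wait_frames) return false;
    const bool allow = static_cast<float>(speech_frames) / frames >=
                       speech_ratio_threshold;
    frames = 0;  // Start a new observation window either way.
    speech_frames = 0;
    return allow;
  }
};
int main() {
  UpdateGate gate{/*update_wait_frames=*/100, /*speech_ratio_threshold=*/0.6f};
  int updates = 0;
  for (int i = 0; i < 300; ++i) {
    if (gate.OnFrame(/*is_speech=*/(i % 10) < 7)) ++updates;  // 70% speech.
  }
  std::cout << updates << "\n";  // 3: one allowed update per 100 frames.
}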
View File

@ -0,0 +1,171 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/input_volume_stats_reporter.h"
#include <cmath>
#include "absl/strings/string_view.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
#include "rtc_base/strings/string_builder.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
namespace {
using InputVolumeType = InputVolumeStatsReporter::InputVolumeType;
constexpr int kFramesIn60Seconds = 6000;
constexpr int kMinInputVolume = 0;
constexpr int kMaxInputVolume = 255;
constexpr int kMaxUpdate = kMaxInputVolume - kMinInputVolume;
int ComputeAverageUpdate(int sum_updates, int num_updates) {
RTC_DCHECK_GE(sum_updates, 0);
RTC_DCHECK_LE(sum_updates, kMaxUpdate * kFramesIn60Seconds);
RTC_DCHECK_GE(num_updates, 0);
RTC_DCHECK_LE(num_updates, kFramesIn60Seconds);
if (num_updates == 0) {
return 0;
}
return std::round(static_cast<float>(sum_updates) /
static_cast<float>(num_updates));
}
constexpr absl::string_view MetricNamePrefix(
InputVolumeType input_volume_type) {
switch (input_volume_type) {
case InputVolumeType::kApplied:
return "WebRTC.Audio.Apm.AppliedInputVolume.";
case InputVolumeType::kRecommended:
return "WebRTC.Audio.Apm.RecommendedInputVolume.";
}
}
metrics::Histogram* CreateVolumeHistogram(InputVolumeType input_volume_type) {
char buffer[64];
rtc::SimpleStringBuilder builder(buffer);
builder << MetricNamePrefix(input_volume_type) << "OnChange";
return metrics::HistogramFactoryGetCountsLinear(/*name=*/builder.str(),
/*min=*/1,
/*max=*/kMaxInputVolume,
/*bucket_count=*/50);
}
metrics::Histogram* CreateRateHistogram(InputVolumeType input_volume_type,
absl::string_view name) {
char buffer[64];
rtc::SimpleStringBuilder builder(buffer);
builder << MetricNamePrefix(input_volume_type) << name;
return metrics::HistogramFactoryGetCountsLinear(/*name=*/builder.str(),
/*min=*/1,
/*max=*/kFramesIn60Seconds,
/*bucket_count=*/50);
}
metrics::Histogram* CreateAverageHistogram(InputVolumeType input_volume_type,
absl::string_view name) {
char buffer[64];
rtc::SimpleStringBuilder builder(buffer);
builder << MetricNamePrefix(input_volume_type) << name;
return metrics::HistogramFactoryGetCountsLinear(/*name=*/builder.str(),
/*min=*/1,
/*max=*/kMaxUpdate,
/*bucket_count=*/50);
}
} // namespace
InputVolumeStatsReporter::InputVolumeStatsReporter(InputVolumeType type)
: histograms_(
{.on_volume_change = CreateVolumeHistogram(type),
.decrease_rate = CreateRateHistogram(type, "DecreaseRate"),
.decrease_average = CreateAverageHistogram(type, "DecreaseAverage"),
.increase_rate = CreateRateHistogram(type, "IncreaseRate"),
.increase_average = CreateAverageHistogram(type, "IncreaseAverage"),
.update_rate = CreateRateHistogram(type, "UpdateRate"),
.update_average = CreateAverageHistogram(type, "UpdateAverage")}),
cannot_log_stats_(!histograms_.AllPointersSet()) {
if (cannot_log_stats_) {
RTC_LOG(LS_WARNING) << "Will not log any `" << MetricNamePrefix(type)
<< "*` histogram stats.";
}
}
InputVolumeStatsReporter::~InputVolumeStatsReporter() = default;
void InputVolumeStatsReporter::UpdateStatistics(int input_volume) {
if (cannot_log_stats_) {
// Since the stats cannot be logged, do not bother updating them.
return;
}
RTC_DCHECK_GE(input_volume, kMinInputVolume);
RTC_DCHECK_LE(input_volume, kMaxInputVolume);
if (previous_input_volume_.has_value() &&
input_volume != previous_input_volume_.value()) {
// Update stats when the input volume changes.
metrics::HistogramAdd(histograms_.on_volume_change, input_volume);
// Update stats that are periodically logged.
const int volume_change = input_volume - previous_input_volume_.value();
if (volume_change < 0) {
++volume_update_stats_.num_decreases;
volume_update_stats_.sum_decreases -= volume_change;
} else {
++volume_update_stats_.num_increases;
volume_update_stats_.sum_increases += volume_change;
}
}
// Periodically log input volume change metrics.
if (++log_volume_update_stats_counter_ >= kFramesIn60Seconds) {
LogVolumeUpdateStats();
volume_update_stats_ = {};
log_volume_update_stats_counter_ = 0;
}
previous_input_volume_ = input_volume;
}
void InputVolumeStatsReporter::LogVolumeUpdateStats() const {
// Decrease rate and average.
metrics::HistogramAdd(histograms_.decrease_rate,
volume_update_stats_.num_decreases);
if (volume_update_stats_.num_decreases > 0) {
int average_decrease = ComputeAverageUpdate(
volume_update_stats_.sum_decreases, volume_update_stats_.num_decreases);
metrics::HistogramAdd(histograms_.decrease_average, average_decrease);
}
// Increase rate and average.
metrics::HistogramAdd(histograms_.increase_rate,
volume_update_stats_.num_increases);
if (volume_update_stats_.num_increases > 0) {
int average_increase = ComputeAverageUpdate(
volume_update_stats_.sum_increases, volume_update_stats_.num_increases);
metrics::HistogramAdd(histograms_.increase_average, average_increase);
}
// Update rate and average.
int num_updates =
volume_update_stats_.num_decreases + volume_update_stats_.num_increases;
metrics::HistogramAdd(histograms_.update_rate, num_updates);
if (num_updates > 0) {
int average_update = ComputeAverageUpdate(
volume_update_stats_.sum_decreases + volume_update_stats_.sum_increases,
num_updates);
metrics::HistogramAdd(histograms_.update_average, average_update);
}
}
void UpdateHistogramOnRecommendedInputVolumeChangeToMatchTarget(int volume) {
RTC_HISTOGRAM_COUNTS_LINEAR(
"WebRTC.Audio.Apm.RecommendedInputVolume.OnChangeToMatchTarget", volume,
1, kMaxInputVolume, 50);
}
} // namespace webrtc
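To make the bookkeeping above concrete, a small standalone sketch that aggregates volume changes the way `VolumeUpdateStats` does, splitting counts and sums by direction and averaging with round-to-nearest; the driver sequence is made up:
#include <cmath>
#include <iostream>
#include <vector>
struct VolumeUpdateStats {
  int num_decreases = 0;
  int num_increases = 0;
  int sum_decreases = 0;  // Sum of absolute decreases.
  int sum_increases = 0;
};
int ComputeAverageUpdate(int sum_updates, int num_updates) {
  if (num_updates == 0) return 0;
  return static_cast<int>(std::round(static_cast<float>(sum_updates) /
                                     static_cast<float>(num_updates)));
}
int main() {
  const std::vector<int> volumes = {100, 120, 120, 90, 95, 60};
  VolumeUpdateStats stats;
  for (size_t i = 1; i < volumes.size(); ++i) {
    const int change = volumes[i] - volumes[i - 1];
    if (change < 0) {
      ++stats.num_decreases;
      stats.sum_decreases -= change;
    } else if (change > 0) {  // Unchanged volumes are not counted as updates.
      ++stats.num_increases;
      stats.sum_increases += change;
    }
  }
  // Decreases: 30 and 35 -> count 2, average 33 (65 / 2 rounded).
  std::cout << stats.num_decreases << " "
            << ComputeAverageUpdate(stats.sum_decreases, stats.num_decreases)
            << "\n";
  // Increases: 20 and 5 -> count 2, average 13 (25 / 2 rounded).
  std::cout << stats.num_increases << " "
            << ComputeAverageUpdate(stats.sum_increases, stats.num_increases)
            << "\n";
}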

View File

@ -0,0 +1,96 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_STATS_REPORTER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_STATS_REPORTER_H_
#include "absl/types/optional.h"
#include "rtc_base/gtest_prod_util.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
// Input volume statistics calculator. Computes aggregate stats based on the
// framewise input volume observed by `UpdateStatistics()`. Periodically logs
// the statistics into a histogram.
class InputVolumeStatsReporter {
public:
enum class InputVolumeType {
kApplied = 0,
kRecommended = 1,
};
explicit InputVolumeStatsReporter(InputVolumeType input_volume_type);
InputVolumeStatsReporter(const InputVolumeStatsReporter&) = delete;
InputVolumeStatsReporter& operator=(const InputVolumeStatsReporter&) = delete;
~InputVolumeStatsReporter();
// Updates the stats based on `input_volume`. Periodically logs the stats into
// a histogram.
void UpdateStatistics(int input_volume);
private:
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
CheckVolumeUpdateStatsForEmptyStats);
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
CheckVolumeUpdateStatsAfterNoVolumeChange);
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
CheckVolumeUpdateStatsAfterVolumeIncrease);
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
CheckVolumeUpdateStatsAfterVolumeDecrease);
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
CheckVolumeUpdateStatsAfterReset);
// Stores input volume update stats to enable calculation of update rate and
// average update separately for volume increases and decreases.
struct VolumeUpdateStats {
int num_decreases = 0;
int num_increases = 0;
int sum_decreases = 0;
int sum_increases = 0;
} volume_update_stats_;
// Returns a copy of the stored statistics. Use only for testing.
VolumeUpdateStats volume_update_stats() const { return volume_update_stats_; }
// Computes aggregate stats and logs them into histograms.
void LogVolumeUpdateStats() const;
// Histograms.
struct Histograms {
metrics::Histogram* const on_volume_change;
metrics::Histogram* const decrease_rate;
metrics::Histogram* const decrease_average;
metrics::Histogram* const increase_rate;
metrics::Histogram* const increase_average;
metrics::Histogram* const update_rate;
metrics::Histogram* const update_average;
bool AllPointersSet() const {
return !!on_volume_change && !!decrease_rate && !!decrease_average &&
!!increase_rate && !!increase_average && !!update_rate &&
!!update_average;
}
} histograms_;
// True if the stats cannot be logged.
const bool cannot_log_stats_;
int log_volume_update_stats_counter_ = 0;
absl::optional<int> previous_input_volume_ = absl::nullopt;
};
// Updates the histogram that keeps track of recommended input volume changes
// required in order to match the target level in the input volume adaptation
// process.
void UpdateHistogramOnRecommendedInputVolumeChangeToMatchTarget(int volume);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_STATS_REPORTER_H_

View File

@ -13,9 +13,11 @@
#include <algorithm>
#include <iterator>
#include "absl/strings/string_view.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/strings/string_builder.h"
namespace webrtc {
@ -28,16 +30,23 @@ constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
InterpolatedGainCurve::approximation_params_q_;
InterpolatedGainCurve::InterpolatedGainCurve(ApmDataDumper* apm_data_dumper,
std::string histogram_name_prefix)
: region_logger_("WebRTC.Audio." + histogram_name_prefix +
".FixedDigitalGainCurveRegion.Identity",
"WebRTC.Audio." + histogram_name_prefix +
".FixedDigitalGainCurveRegion.Knee",
"WebRTC.Audio." + histogram_name_prefix +
".FixedDigitalGainCurveRegion.Limiter",
"WebRTC.Audio." + histogram_name_prefix +
".FixedDigitalGainCurveRegion.Saturation"),
InterpolatedGainCurve::InterpolatedGainCurve(
ApmDataDumper* apm_data_dumper,
absl::string_view histogram_name_prefix)
: region_logger_(
(rtc::StringBuilder("WebRTC.Audio.")
<< histogram_name_prefix << ".FixedDigitalGainCurveRegion.Identity")
.str(),
(rtc::StringBuilder("WebRTC.Audio.")
<< histogram_name_prefix << ".FixedDigitalGainCurveRegion.Knee")
.str(),
(rtc::StringBuilder("WebRTC.Audio.")
<< histogram_name_prefix << ".FixedDigitalGainCurveRegion.Limiter")
.str(),
(rtc::StringBuilder("WebRTC.Audio.")
<< histogram_name_prefix
<< ".FixedDigitalGainCurveRegion.Saturation")
.str()),
apm_data_dumper_(apm_data_dumper) {}
InterpolatedGainCurve::~InterpolatedGainCurve() {
@ -56,10 +65,10 @@ InterpolatedGainCurve::~InterpolatedGainCurve() {
}
InterpolatedGainCurve::RegionLogger::RegionLogger(
std::string identity_histogram_name,
std::string knee_histogram_name,
std::string limiter_histogram_name,
std::string saturation_histogram_name)
absl::string_view identity_histogram_name,
absl::string_view knee_histogram_name,
absl::string_view limiter_histogram_name,
absl::string_view saturation_histogram_name)
: identity_histogram(
metrics::HistogramFactoryGetCounts(identity_histogram_name,
1,
@ -114,7 +123,7 @@ void InterpolatedGainCurve::RegionLogger::LogRegionStats(
break;
}
default: {
RTC_NOTREACHED();
RTC_DCHECK_NOTREACHED();
}
}
}
@ -150,11 +159,11 @@ void InterpolatedGainCurve::UpdateStats(float input_level) const {
}
// Looks up a gain to apply given a non-negative input level.
// The cost of this operation depends on the region in which |input_level|
// The cost of this operation depends on the region in which `input_level`
// falls.
// For the identity and the saturation regions the cost is O(1).
// For the other regions, namely knee and limiter, the cost is
// O(2 + log2(|kInterpolatedGainCurveTotalPoints|), plus O(1) for the
// O(2 + log2(`kInterpolatedGainCurveTotalPoints`), plus O(1) for the
// linear interpolation (one product and one sum).
float InterpolatedGainCurve::LookUpGainToApply(float input_level) const {
UpdateStats(input_level);
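The cost analysis above describes a standard piecewise table lookup: constant time in the flat regions, a binary search plus one linear interpolation elsewhere. A generic, self-contained sketch of that pattern; the table values are invented and are not the AGC2 gain curve:
#include <algorithm>
#include <array>
#include <iostream>
constexpr std::array<float, 5> kX = {0.0f, 0.25f, 0.5f, 0.75f, 1.0f};
constexpr std::array<float, 5> kY = {1.0f, 1.0f, 0.9f, 0.7f, 0.5f};
float LookUp(float x) {
  if (x <= kX.front()) return kY.front();  // Identity-like region: O(1).
  if (x >= kX.back()) return kY.back();    // Saturation-like region: O(1).
  // O(log2(N)) binary search for the segment containing `x`.
  const auto it = std::upper_bound(kX.begin(), kX.end(), x);
  const size_t i = static_cast<size_t>(it - kX.begin());
  const float t = (x - kX[i - 1]) / (kX[i] - kX[i - 1]);
  // Linear interpolation: one product and one sum (plus computing `t`).
  return kY[i - 1] + t * (kY[i] - kY[i - 1]);
}
int main() {
  std::cout << LookUp(0.6f) << "\n";  // 0.82, between 0.9 and 0.7.
}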

View File

@ -12,10 +12,9 @@
#define MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
#include <array>
#include <string>
#include "absl/strings/string_view.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/constructor_magic.h"
#include "rtc_base/gtest_prod_util.h"
#include "system_wrappers/include/metrics.h"
@ -61,9 +60,12 @@ class InterpolatedGainCurve {
};
InterpolatedGainCurve(ApmDataDumper* apm_data_dumper,
std::string histogram_name_prefix);
absl::string_view histogram_name_prefix);
~InterpolatedGainCurve();
InterpolatedGainCurve(const InterpolatedGainCurve&) = delete;
InterpolatedGainCurve& operator=(const InterpolatedGainCurve&) = delete;
Stats get_stats() const { return stats_; }
// Given a non-negative input level (linear scale), a scalar factor to apply
@ -75,7 +77,7 @@ class InterpolatedGainCurve {
private:
// For comparing 'approximation_params_*_' with ones computed by
// ComputeInterpolatedGainCurve.
FRIEND_TEST_ALL_PREFIXES(AutomaticGainController2InterpolatedGainCurve,
FRIEND_TEST_ALL_PREFIXES(GainController2InterpolatedGainCurve,
CheckApproximationParams);
struct RegionLogger {
@ -84,10 +86,10 @@ class InterpolatedGainCurve {
metrics::Histogram* limiter_histogram;
metrics::Histogram* saturation_histogram;
RegionLogger(std::string identity_histogram_name,
std::string knee_histogram_name,
std::string limiter_histogram_name,
std::string saturation_histogram_name);
RegionLogger(absl::string_view identity_histogram_name,
absl::string_view knee_histogram_name,
absl::string_view limiter_histogram_name,
absl::string_view saturation_histogram_name);
~RegionLogger();
@ -143,8 +145,6 @@ class InterpolatedGainCurve {
// Stats.
mutable Stats stats_;
RTC_DISALLOW_COPY_AND_ASSIGN(InterpolatedGainCurve);
};
} // namespace webrtc

View File

@ -14,10 +14,12 @@
#include <array>
#include <cmath>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
@ -29,14 +31,14 @@ namespace {
// sub-frame, linear interpolation is replaced with a power function which
// reduces the chances of over-shooting (and hence saturation), however reducing
// the fixed gain effectiveness.
constexpr float kAttackFirstSubframeInterpolationPower = 8.f;
constexpr float kAttackFirstSubframeInterpolationPower = 8.0f;
void InterpolateFirstSubframe(float last_factor,
float current_factor,
rtc::ArrayView<float> subframe) {
const auto n = subframe.size();
constexpr auto p = kAttackFirstSubframeInterpolationPower;
for (size_t i = 0; i < n; ++i) {
const int n = rtc::dchecked_cast<int>(subframe.size());
constexpr float p = kAttackFirstSubframeInterpolationPower;
for (int i = 0; i < n; ++i) {
subframe[i] = std::pow(1.f - i / n, p) * (last_factor - current_factor) +
current_factor;
}
@ -44,10 +46,10 @@ void InterpolateFirstSubframe(float last_factor,
void ComputePerSampleSubframeFactors(
const std::array<float, kSubFramesInFrame + 1>& scaling_factors,
size_t samples_per_channel,
int samples_per_channel,
rtc::ArrayView<float> per_sample_scaling_factors) {
const size_t num_subframes = scaling_factors.size() - 1;
const size_t subframe_size =
const int num_subframes = scaling_factors.size() - 1;
const int subframe_size =
rtc::CheckedDivExact(samples_per_channel, num_subframes);
// Handle first sub-frame differently in case of attack.
@ -59,12 +61,12 @@ void ComputePerSampleSubframeFactors(
per_sample_scaling_factors.subview(0, subframe_size)));
}
for (size_t i = is_attack ? 1 : 0; i < num_subframes; ++i) {
const size_t subframe_start = i * subframe_size;
for (int i = is_attack ? 1 : 0; i < num_subframes; ++i) {
const int subframe_start = i * subframe_size;
const float scaling_start = scaling_factors[i];
const float scaling_end = scaling_factors[i + 1];
const float scaling_diff = (scaling_end - scaling_start) / subframe_size;
for (size_t j = 0; j < subframe_size; ++j) {
for (int j = 0; j < subframe_size; ++j) {
per_sample_scaling_factors[subframe_start + j] =
scaling_start + scaling_diff * j;
}
@ -73,18 +75,18 @@ void ComputePerSampleSubframeFactors(
void ScaleSamples(rtc::ArrayView<const float> per_sample_scaling_factors,
AudioFrameView<float> signal) {
const size_t samples_per_channel = signal.samples_per_channel();
const int samples_per_channel = signal.samples_per_channel();
RTC_DCHECK_EQ(samples_per_channel, per_sample_scaling_factors.size());
for (size_t i = 0; i < signal.num_channels(); ++i) {
auto channel = signal.channel(i);
for (size_t j = 0; j < samples_per_channel; ++j) {
for (int i = 0; i < signal.num_channels(); ++i) {
rtc::ArrayView<float> channel = signal.channel(i);
for (int j = 0; j < samples_per_channel; ++j) {
channel[j] = rtc::SafeClamp(channel[j] * per_sample_scaling_factors[j],
kMinFloatS16Value, kMaxFloatS16Value);
}
}
}
void CheckLimiterSampleRate(size_t sample_rate_hz) {
void CheckLimiterSampleRate(int sample_rate_hz) {
// Check that per_sample_scaling_factors_ is large enough.
RTC_DCHECK_LE(sample_rate_hz,
kMaximalNumberOfSamplesPerChannel * 1000 / kFrameDurationMs);
@ -92,9 +94,9 @@ void CheckLimiterSampleRate(size_t sample_rate_hz) {
} // namespace
Limiter::Limiter(size_t sample_rate_hz,
Limiter::Limiter(int sample_rate_hz,
ApmDataDumper* apm_data_dumper,
std::string histogram_name)
absl::string_view histogram_name)
: interp_gain_curve_(apm_data_dumper, histogram_name),
level_estimator_(sample_rate_hz, apm_data_dumper),
apm_data_dumper_(apm_data_dumper) {
@ -104,7 +106,8 @@ Limiter::Limiter(size_t sample_rate_hz,
Limiter::~Limiter() = default;
void Limiter::Process(AudioFrameView<float> signal) {
const auto level_estimate = level_estimator_.ComputeLevel(signal);
const std::array<float, kSubFramesInFrame> level_estimate =
level_estimator_.ComputeLevel(signal);
RTC_DCHECK_EQ(level_estimate.size() + 1, scaling_factors_.size());
scaling_factors_[0] = last_scaling_factor_;
@ -113,7 +116,7 @@ void Limiter::Process(AudioFrameView<float> signal) {
return interp_gain_curve_.LookUpGainToApply(x);
});
const size_t samples_per_channel = signal.samples_per_channel();
const int samples_per_channel = signal.samples_per_channel();
RTC_DCHECK_LE(samples_per_channel, kMaximalNumberOfSamplesPerChannel);
auto per_sample_scaling_factors = rtc::ArrayView<float>(
@ -125,16 +128,18 @@ void Limiter::Process(AudioFrameView<float> signal) {
last_scaling_factor_ = scaling_factors_.back();
// Dump data for debug.
apm_data_dumper_->DumpRaw("agc2_gain_curve_applier_scaling_factors",
samples_per_channel,
per_sample_scaling_factors_.data());
apm_data_dumper_->DumpRaw("agc2_limiter_last_scaling_factor",
last_scaling_factor_);
apm_data_dumper_->DumpRaw(
"agc2_limiter_region",
static_cast<int>(interp_gain_curve_.get_stats().region));
}
InterpolatedGainCurve::Stats Limiter::GetGainCurveStats() const {
return interp_gain_curve_.get_stats();
}
void Limiter::SetSampleRate(size_t sample_rate_hz) {
void Limiter::SetSampleRate(int sample_rate_hz) {
CheckLimiterSampleRate(sample_rate_hz);
level_estimator_.SetSampleRate(sample_rate_hz);
}
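The attack handling above replaces linear interpolation with a power-shaped ramp so that most of the gain drop lands at the start of the sub-frame, reducing overshoot. A standalone sketch of the same formula; the exponent matches `kAttackFirstSubframeInterpolationPower`, while the frame length and factor values are arbitrary:
#include <cmath>
#include <iostream>
#include <vector>
// pow(1 - i/n, p) weights the previous factor, as in
// InterpolateFirstSubframe() above.
std::vector<float> PowerInterpolate(float last_factor,
                                    float current_factor,
                                    int n,
                                    float p) {
  std::vector<float> out(n);
  for (int i = 0; i < n; ++i) {
    out[i] = std::pow(1.0f - static_cast<float>(i) / n, p) *
                 (last_factor - current_factor) +
             current_factor;
  }
  return out;
}
int main() {
  // Attack from factor 1.0 towards 0.5 with p = 8: the first few samples
  // carry most of the drop.
  for (float v : PowerInterpolate(1.0f, 0.5f, /*n=*/8, /*p=*/8.0f)) {
    std::cout << v << " ";  // 1 0.672 0.55 0.512 0.502 ...
  }
  std::cout << "\n";
}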

View File

@ -11,27 +11,26 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
#include "modules/audio_processing/agc2/interpolated_gain_curve.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/constructor_magic.h"
namespace webrtc {
class ApmDataDumper;
class Limiter {
public:
Limiter(size_t sample_rate_hz,
Limiter(int sample_rate_hz,
ApmDataDumper* apm_data_dumper,
std::string histogram_name_prefix);
absl::string_view histogram_name_prefix);
Limiter(const Limiter& limiter) = delete;
Limiter& operator=(const Limiter& limiter) = delete;
~Limiter();
// Applies limiter and hard-clipping to |signal|.
// Applies limiter and hard-clipping to `signal`.
void Process(AudioFrameView<float> signal);
InterpolatedGainCurve::Stats GetGainCurveStats() const;
@ -40,7 +39,7 @@ class Limiter {
// * below kMaximalNumberOfSamplesPerChannel*1000/kFrameDurationMs
// so that samples_per_channel fit in the
// per_sample_scaling_factors_ array.
void SetSampleRate(size_t sample_rate_hz);
void SetSampleRate(int sample_rate_hz);
// Resets the internal state.
void Reset();

View File

@ -105,7 +105,7 @@ double LimiterDbGainCurve::GetGainLinear(double input_level_linear) const {
input_level_linear;
}
// Computes the first derivative of GetGainLinear() in |x|.
// Computes the first derivative of GetGainLinear() in `x`.
double LimiterDbGainCurve::GetGainFirstDerivativeLinear(double x) const {
// Beyond-knee region only.
RTC_CHECK_GE(x, limiter_start_linear_ - 1e-7 * kMaxAbsFloatS16Value);

View File

@ -17,98 +17,156 @@
#include <numeric>
#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr int kFramesPerSecond = 100;
float FrameEnergy(const AudioFrameView<const float>& audio) {
float energy = 0.f;
for (size_t k = 0; k < audio.num_channels(); ++k) {
float energy = 0.0f;
for (int k = 0; k < audio.num_channels(); ++k) {
float channel_energy =
std::accumulate(audio.channel(k).begin(), audio.channel(k).end(), 0.f,
std::accumulate(audio.channel(k).begin(), audio.channel(k).end(), 0.0f,
[](float a, float b) -> float { return a + b * b; });
energy = std::max(channel_energy, energy);
}
return energy;
}
float EnergyToDbfs(float signal_energy, size_t num_samples) {
const float rms = std::sqrt(signal_energy / num_samples);
return FloatS16ToDbfs(rms);
float EnergyToDbfs(float signal_energy, int num_samples) {
RTC_DCHECK_GE(signal_energy, 0.0f);
const float rms_square = signal_energy / num_samples;
constexpr float kMinDbfs = -90.30899869919436f;
if (rms_square <= 1.0f) {
return kMinDbfs;
}
return 10.0f * std::log10(rms_square) + kMinDbfs;
}
// Updates the noise floor with instant decay and slow attack. This tuning is
// specific for AGC2, so that (i) it can promptly increase the gain if the noise
// floor drops (instant decay) and (ii) in case of music or fast speech, which
// can cause the noise floor to be overestimated, the gain reduction is slowed
// down.
float SmoothNoiseFloorEstimate(float current_estimate, float new_estimate) {
constexpr float kAttack = 0.5f;
if (current_estimate < new_estimate) {
// Attack phase.
return kAttack * new_estimate + (1.0f - kAttack) * current_estimate;
}
// Instant attack.
return new_estimate;
}
class NoiseFloorEstimator : public NoiseLevelEstimator {
public:
// Update the noise floor every 5 seconds.
static constexpr int kUpdatePeriodNumFrames = 500;
static_assert(kUpdatePeriodNumFrames >= 200,
"A too small value may cause noise level overestimation.");
static_assert(kUpdatePeriodNumFrames <= 1500,
"A too large value may make AGC2 slow at reacting to increased "
"noise levels.");
NoiseFloorEstimator(ApmDataDumper* data_dumper) : data_dumper_(data_dumper) {
RTC_DCHECK(data_dumper_);
// Initially assume that 48 kHz will be used. `Analyze()` will detect the
// used sample rate and call `Initialize()` again if needed.
Initialize(/*sample_rate_hz=*/48000);
}
NoiseFloorEstimator(const NoiseFloorEstimator&) = delete;
NoiseFloorEstimator& operator=(const NoiseFloorEstimator&) = delete;
~NoiseFloorEstimator() = default;
float Analyze(const AudioFrameView<const float>& frame) override {
// Detect sample rate changes.
const int sample_rate_hz =
static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
if (sample_rate_hz != sample_rate_hz_) {
Initialize(sample_rate_hz);
}
const float frame_energy = FrameEnergy(frame);
if (frame_energy <= min_noise_energy_) {
// Ignore frames when muted or below the minimum measurable energy.
if (data_dumper_)
data_dumper_->DumpRaw("agc2_noise_floor_estimator_preliminary_level",
noise_energy_);
return EnergyToDbfs(noise_energy_,
static_cast<int>(frame.samples_per_channel()));
}
if (preliminary_noise_energy_set_) {
preliminary_noise_energy_ =
std::min(preliminary_noise_energy_, frame_energy);
} else {
preliminary_noise_energy_ = frame_energy;
preliminary_noise_energy_set_ = true;
}
if (data_dumper_)
data_dumper_->DumpRaw("agc2_noise_floor_estimator_preliminary_level",
preliminary_noise_energy_);
if (counter_ == 0) {
// Full period observed.
first_period_ = false;
// Update the estimated noise floor energy with the preliminary
// estimation.
noise_energy_ = SmoothNoiseFloorEstimate(
/*current_estimate=*/noise_energy_,
/*new_estimate=*/preliminary_noise_energy_);
// Reset for a new observation period.
counter_ = kUpdatePeriodNumFrames;
preliminary_noise_energy_set_ = false;
} else if (first_period_) {
// While analyzing the signal during the initial period, continuously
// update the estimated noise energy, which is monotonic.
noise_energy_ = preliminary_noise_energy_;
counter_--;
} else {
// During the observation period it's only allowed to lower the energy.
noise_energy_ = std::min(noise_energy_, preliminary_noise_energy_);
counter_--;
}
float noise_rms_dbfs = EnergyToDbfs(
noise_energy_, static_cast<int>(frame.samples_per_channel()));
if (data_dumper_)
data_dumper_->DumpRaw("agc2_noise_rms_dbfs", noise_rms_dbfs);
return noise_rms_dbfs;
}
private:
void Initialize(int sample_rate_hz) {
sample_rate_hz_ = sample_rate_hz;
first_period_ = true;
preliminary_noise_energy_set_ = false;
// Initialize the minimum noise energy to -84 dBFS.
min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond;
preliminary_noise_energy_ = min_noise_energy_;
noise_energy_ = min_noise_energy_;
counter_ = kUpdatePeriodNumFrames;
}
ApmDataDumper* const data_dumper_;
int sample_rate_hz_;
float min_noise_energy_;
bool first_period_;
bool preliminary_noise_energy_set_;
float preliminary_noise_energy_;
float noise_energy_;
int counter_;
};
} // namespace
NoiseLevelEstimator::NoiseLevelEstimator(ApmDataDumper* data_dumper)
: signal_classifier_(data_dumper) {
Initialize(48000);
}
NoiseLevelEstimator::~NoiseLevelEstimator() {}
void NoiseLevelEstimator::Initialize(int sample_rate_hz) {
sample_rate_hz_ = sample_rate_hz;
noise_energy_ = 1.f;
first_update_ = true;
min_noise_energy_ = sample_rate_hz * 2.f * 2.f / kFramesPerSecond;
noise_energy_hold_counter_ = 0;
signal_classifier_.Initialize(sample_rate_hz);
}
float NoiseLevelEstimator::Analyze(const AudioFrameView<const float>& frame) {
const int rate =
static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
if (rate != sample_rate_hz_) {
Initialize(rate);
}
const float frame_energy = FrameEnergy(frame);
if (frame_energy <= 0.f) {
RTC_DCHECK_GE(frame_energy, 0.f);
return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
}
if (first_update_) {
// Initialize the noise energy to the frame energy.
first_update_ = false;
return EnergyToDbfs(
noise_energy_ = std::max(frame_energy, min_noise_energy_),
frame.samples_per_channel());
}
const SignalClassifier::SignalType signal_type =
signal_classifier_.Analyze(frame.channel(0));
// Update the noise estimate in a minimum statistics-type manner.
if (signal_type == SignalClassifier::SignalType::kStationary) {
if (frame_energy > noise_energy_) {
// Leak the estimate upwards towards the frame energy if no recent
// downward update.
noise_energy_hold_counter_ = std::max(noise_energy_hold_counter_ - 1, 0);
if (noise_energy_hold_counter_ == 0) {
noise_energy_ = std::min(noise_energy_ * 1.01f, frame_energy);
}
} else {
// Update smoothly downwards with a limited maximum update magnitude.
noise_energy_ =
std::max(noise_energy_ * 0.9f,
noise_energy_ + 0.05f * (frame_energy - noise_energy_));
noise_energy_hold_counter_ = 1000;
}
} else {
// For a non-stationary signal, leak the estimate downwards in order to
// avoid estimate locking due to incorrect signal classification.
noise_energy_ = noise_energy_ * 0.99f;
}
// Ensure a minimum of the estimate.
return EnergyToDbfs(
noise_energy_ = std::max(noise_energy_, min_noise_energy_),
frame.samples_per_channel());
std::unique_ptr<NoiseLevelEstimator> CreateNoiseFloorEstimator(
ApmDataDumper* data_dumper) {
return std::make_unique<NoiseFloorEstimator>(data_dumper);
}
} // namespace webrtc
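A standalone illustration of the asymmetric smoothing used by the noise floor update above, smoothing towards a higher estimate (slow attack) and jumping immediately to a lower one (instant decay); the driver values are made up:
#include <iostream>
// Same rule as SmoothNoiseFloorEstimate() above, reproduced here so the
// example is self-contained.
float SmoothNoiseFloorEstimate(float current_estimate, float new_estimate) {
  constexpr float kAttack = 0.5f;
  if (current_estimate < new_estimate) {
    return kAttack * new_estimate + (1.0f - kAttack) * current_estimate;
  }
  return new_estimate;  // Instant decay.
}
int main() {
  float estimate = 100.0f;
  estimate = SmoothNoiseFloorEstimate(estimate, 400.0f);
  std::cout << estimate << "\n";  // 250: halfway towards the higher estimate.
  estimate = SmoothNoiseFloorEstimate(estimate, 50.0f);
  std::cout << estimate << "\n";  // 50: drops instantly.
}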

View File

@ -11,33 +11,26 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
#include "modules/audio_processing/agc2/signal_classifier.h"
#include <memory>
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/constructor_magic.h"
namespace webrtc {
class ApmDataDumper;
// Noise level estimator interface.
class NoiseLevelEstimator {
public:
NoiseLevelEstimator(ApmDataDumper* data_dumper);
~NoiseLevelEstimator();
// Returns the estimated noise level in dBFS.
float Analyze(const AudioFrameView<const float>& frame);
private:
void Initialize(int sample_rate_hz);
int sample_rate_hz_;
float min_noise_energy_;
bool first_update_;
float noise_energy_;
int noise_energy_hold_counter_;
SignalClassifier signal_classifier_;
RTC_DISALLOW_COPY_AND_ASSIGN(NoiseLevelEstimator);
virtual ~NoiseLevelEstimator() = default;
// Analyzes a 10 ms `frame`, updates the noise level estimate and returns its
// value in dBFS.
virtual float Analyze(const AudioFrameView<const float>& frame) = 0;
};
// Creates a noise level estimator based on noise floor detection.
std::unique_ptr<NoiseLevelEstimator> CreateNoiseFloorEstimator(
ApmDataDumper* data_dumper);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
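The `EnergyToDbfs()` helper in the .cc file above maps a full-scale S16 RMS of 32768 to 0 dBFS; its `kMinDbfs` offset is 20*log10(1/32768), roughly -90.309, i.e. the level of an RMS of 1. A quick standalone check that the offset form and the direct form agree; the test RMS value is arbitrary:
#include <cmath>
#include <iostream>
int main() {
  constexpr float kMinDbfs = -90.30899869919436f;  // 20*log10(1/32768).
  const float rms = 1024.0f;
  // Offset form, as in EnergyToDbfs(): 10*log10(rms^2) + kMinDbfs.
  const float via_offset = 10.0f * std::log10(rms * rms) + kMinDbfs;
  // Direct form: 20*log10(rms / 32768).
  const float direct = 20.0f * std::log10(rms / 32768.0f);
  std::cout << via_offset << " " << direct << "\n";  // Both ~-30.103 dBFS.
}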

View File

@ -1,70 +0,0 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/noise_spectrum_estimator.h"
#include <string.h>
#include <algorithm>
#include "api/array_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/arraysize.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr float kMinNoisePower = 100.f;
} // namespace
NoiseSpectrumEstimator::NoiseSpectrumEstimator(ApmDataDumper* data_dumper)
: data_dumper_(data_dumper) {
Initialize();
}
void NoiseSpectrumEstimator::Initialize() {
std::fill(noise_spectrum_, noise_spectrum_ + arraysize(noise_spectrum_),
kMinNoisePower);
}
void NoiseSpectrumEstimator::Update(rtc::ArrayView<const float> spectrum,
bool first_update) {
RTC_DCHECK_EQ(65, spectrum.size());
if (first_update) {
// Initialize the noise spectral estimate with the signal spectrum.
std::copy(spectrum.data(), spectrum.data() + spectrum.size(),
noise_spectrum_);
} else {
// Smoothly update the noise spectral estimate towards the signal spectrum
// such that the magnitude of the updates is limited.
for (size_t k = 0; k < spectrum.size(); ++k) {
if (noise_spectrum_[k] < spectrum[k]) {
noise_spectrum_[k] = std::min(
1.01f * noise_spectrum_[k],
noise_spectrum_[k] + 0.05f * (spectrum[k] - noise_spectrum_[k]));
} else {
noise_spectrum_[k] = std::max(
0.99f * noise_spectrum_[k],
noise_spectrum_[k] + 0.05f * (spectrum[k] - noise_spectrum_[k]));
}
}
}
// Ensure that the noise spectral estimate does not become too low.
for (auto& v : noise_spectrum_) {
v = std::max(v, kMinNoisePower);
}
data_dumper_->DumpRaw("lc_noise_spectrum", 65, noise_spectrum_);
data_dumper_->DumpRaw("lc_signal_spectrum", spectrum);
}
} // namespace webrtc

View File

@ -1,42 +0,0 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_
#include "api/array_view.h"
namespace webrtc {
class ApmDataDumper;
class NoiseSpectrumEstimator {
public:
explicit NoiseSpectrumEstimator(ApmDataDumper* data_dumper);
NoiseSpectrumEstimator() = delete;
NoiseSpectrumEstimator(const NoiseSpectrumEstimator&) = delete;
NoiseSpectrumEstimator& operator=(const NoiseSpectrumEstimator&) = delete;
void Initialize();
void Update(rtc::ArrayView<const float> spectrum, bool first_update);
rtc::ArrayView<const float> GetNoiseSpectrum() const {
return rtc::ArrayView<const float>(noise_spectrum_);
}
private:
ApmDataDumper* data_dumper_;
float noise_spectrum_[65];
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_

View File

@ -17,6 +17,7 @@ rtc_library("rnn_vad") {
"rnn.h",
]
defines = []
if (rtc_build_with_neon && current_cpu != "arm64") {
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
cflags = [ "-mfpu=neon" ]
@ -24,16 +25,17 @@ rtc_library("rnn_vad") {
deps = [
":rnn_vad_common",
":rnn_vad_layers",
":rnn_vad_lp_residual",
":rnn_vad_pitch",
":rnn_vad_sequence_buffer",
":rnn_vad_spectral_features",
"..:biquad_filter",
"..:cpu_features",
"../../../../api:array_view",
"../../../../api:function_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:logging",
"../../../../rtc_base/system:arch",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base:safe_conversions",
"//third_party/rnnoise:rnn_vad",
]
}
@ -51,16 +53,13 @@ rtc_library("rnn_vad_auto_correlation") {
]
}
rtc_library("rnn_vad_common") {
rtc_source_set("rnn_vad_common") {
# TODO(alessiob): Make this target visibility private.
visibility = [
":*",
"..:rnn_vad_with_level",
]
sources = [
"common.cc",
"common.h",
"..:vad_wrapper",
]
sources = [ "common.h" ]
deps = [
"../../../../rtc_base/system:arch",
"../../../../system_wrappers",
@ -75,23 +74,100 @@ rtc_library("rnn_vad_lp_residual") {
deps = [
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_compare",
]
}
rtc_source_set("rnn_vad_layers") {
sources = [
"rnn_fc.cc",
"rnn_fc.h",
"rnn_gru.cc",
"rnn_gru.h",
]
defines = []
if (rtc_build_with_neon && current_cpu != "arm64") {
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
cflags = [ "-mfpu=neon" ]
}
deps = [
":rnn_vad_common",
":vector_math",
"..:cpu_features",
"../../../../api:array_view",
"../../../../api:function_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_conversions",
"//third_party/rnnoise:rnn_vad",
]
if (current_cpu == "x86" || current_cpu == "x64") {
deps += [ ":vector_math_avx2" ]
}
absl_deps = [ "//third_party/abseil-cpp/absl/strings" ]
}
rtc_source_set("vector_math") {
sources = [ "vector_math.h" ]
deps = [
"..:cpu_features",
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_conversions",
"../../../../rtc_base/system:arch",
]
}
if (current_cpu == "x86" || current_cpu == "x64") {
rtc_library("vector_math_avx2") {
sources = [ "vector_math_avx2.cc" ]
if (is_win) {
cflags = [ "/arch:AVX2" ]
} else {
cflags = [
"-mavx2",
"-mfma",
]
}
deps = [
":vector_math",
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_conversions",
]
}
}
rtc_library("rnn_vad_pitch") {
sources = [
"pitch_info.h",
"pitch_search.cc",
"pitch_search.h",
"pitch_search_internal.cc",
"pitch_search_internal.h",
]
defines = []
if (rtc_build_with_neon && current_cpu != "arm64") {
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
cflags = [ "-mfpu=neon" ]
}
deps = [
":rnn_vad_auto_correlation",
":rnn_vad_common",
":vector_math",
"..:cpu_features",
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:gtest_prod",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base:safe_conversions",
"../../../../rtc_base/system:arch",
]
if (current_cpu == "x86" || current_cpu == "x64") {
deps += [ ":vector_math_avx2" ]
}
}
rtc_source_set("rnn_vad_ring_buffer") {
@ -123,6 +199,7 @@ rtc_library("rnn_vad_spectral_features") {
":rnn_vad_symmetric_matrix_buffer",
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_compare",
"../../utility:pffft_wrapper",
]
}
@ -132,6 +209,7 @@ rtc_source_set("rnn_vad_symmetric_matrix_buffer") {
deps = [
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_compare",
]
}
@ -148,11 +226,11 @@ if (rtc_include_tests) {
"../../../../api:array_view",
"../../../../api:scoped_refptr",
"../../../../rtc_base:checks",
"../../../../rtc_base/system:arch",
"../../../../system_wrappers",
"../../../../rtc_base:safe_compare",
"../../../../test:fileutils",
"../../../../test:test_support",
]
absl_deps = [ "//third_party/abseil-cpp/absl/strings" ]
}
unittest_resources = [
@ -181,17 +259,28 @@ if (rtc_include_tests) {
"pitch_search_internal_unittest.cc",
"pitch_search_unittest.cc",
"ring_buffer_unittest.cc",
"rnn_fc_unittest.cc",
"rnn_gru_unittest.cc",
"rnn_unittest.cc",
"rnn_vad_unittest.cc",
"sequence_buffer_unittest.cc",
"spectral_features_internal_unittest.cc",
"spectral_features_unittest.cc",
"symmetric_matrix_buffer_unittest.cc",
"vector_math_unittest.cc",
]
defines = []
if (rtc_build_with_neon && current_cpu != "arm64") {
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
cflags = [ "-mfpu=neon" ]
}
deps = [
":rnn_vad",
":rnn_vad_auto_correlation",
":rnn_vad_common",
":rnn_vad_layers",
":rnn_vad_lp_residual",
":rnn_vad_pitch",
":rnn_vad_ring_buffer",
@ -199,20 +288,47 @@ if (rtc_include_tests) {
":rnn_vad_spectral_features",
":rnn_vad_symmetric_matrix_buffer",
":test_utils",
":vector_math",
"..:cpu_features",
"../..:audioproc_test_utils",
"../../../../api:array_view",
"../../../../common_audio/",
"../../../../rtc_base:checks",
"../../../../rtc_base:logging",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base:safe_conversions",
"../../../../rtc_base:stringutils",
"../../../../rtc_base/system:arch",
"../../../../test:test_support",
"../../utility:pffft_wrapper",
"//third_party/rnnoise:rnn_vad",
]
if (current_cpu == "x86" || current_cpu == "x64") {
deps += [ ":vector_math_avx2" ]
}
absl_deps = [ "//third_party/abseil-cpp/absl/memory" ]
data = unittest_resources
if (is_ios) {
deps += [ ":unittests_bundle_data" ]
}
}
if (!build_with_chromium) {
rtc_executable("rnn_vad_tool") {
testonly = true
sources = [ "rnn_vad_tool.cc" ]
deps = [
":rnn_vad",
":rnn_vad_common",
"..:cpu_features",
"../../../../api:array_view",
"../../../../common_audio",
"../../../../rtc_base:logging",
"../../../../rtc_base:safe_compare",
"../../../../test:test_support",
"//third_party/abseil-cpp/absl/flags:flag",
"//third_party/abseil-cpp/absl/flags:parse",
]
}
}
}

View File

@ -20,7 +20,7 @@ namespace {
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
static_assert(1 << kAutoCorrelationFftOrder >
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
"");
} // namespace
@ -40,20 +40,20 @@ AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
// [ y_{m-1} ]
// x and y are sub-arrays of equal length; x is never moved, whereas y slides.
// The cross-correlation between y_0 and x corresponds to the auto-correlation
// for the maximum pitch period. Hence, the first value in |auto_corr| has an
// for the maximum pitch period. Hence, the first value in `auto_corr` has an
// inverted lag equal to 0 that corresponds to a lag equal to the maximum
// pitch period.
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder;
constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
"Mismatch between pitch buffer size, frame size and maximum "
"pitch period.");
static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength,
static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
"The FFT length is not sufficiently big to avoid cyclic "
"convolution errors.");
auto tmp = tmp_->GetView();
@ -67,13 +67,12 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
// Compute the FFT for the sliding frames chunk. The sliding frames are
// defined as pitch_buf[i:i+kConvolutionLength] where i in
// [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is
// defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength].
// [0, kNumLags12kHz). The chunk includes all of them, hence it is
// defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
std::copy(pitch_buf.begin(),
pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz,
pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
tmp.begin());
std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(),
0.f);
std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
// Convolve in the frequency domain.
@ -84,7 +83,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
// Extract the auto-correlation coefficients.
std::copy(tmp.begin() + kConvolutionLength - 1,
tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1,
tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
auto_corr.begin());
}
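For reference, the FFT-based routine above computes the same values as this brute-force version (a sketch; the production code exists precisely to avoid the O(N*M) inner loop). Inverted lag 0 corresponds to the maximum pitch period:

#include <numeric>

// Brute-force equivalent of ComputeOnPitchBuffer() at 12 kHz (sketch).
void ComputeOnPitchBufferBruteForce(
    rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
    rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
  for (int inverted_lag = 0; inverted_lag < kNumLags12kHz; ++inverted_lag) {
    // Fixed frame: the most recent kBufSize12kHz - kMaxPitch12kHz samples.
    // Sliding frame: same length, starting at `inverted_lag`.
    auto_corr[inverted_lag] = std::inner_product(
        pitch_buf.begin() + kMaxPitch12kHz, pitch_buf.end(),
        pitch_buf.begin() + inverted_lag, 0.f);
  }
}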

View File

@ -31,10 +31,10 @@ class AutoCorrelationCalculator {
~AutoCorrelationCalculator();
// Computes the auto-correlation coefficients for a target pitch interval.
// |auto_corr| indexes are inverted lags.
// `auto_corr` indexes are inverted lags.
void ComputeOnPitchBuffer(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
rtc::ArrayView<float, kNumLags12kHz> auto_corr);
private:
Pffft fft_;

View File

@ -1,34 +0,0 @@
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/system/arch.h"
#include "system_wrappers/include/cpu_features_wrapper.h"
namespace webrtc {
namespace rnn_vad {
Optimization DetectOptimization() {
#if defined(WEBRTC_ARCH_X86_FAMILY)
if (GetCPUInfo(kSSE2) != 0) {
return Optimization::kSse2;
}
#endif
#if defined(WEBRTC_HAS_NEON)
return Optimization::kNeon;
#endif
return Optimization::kNone;
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -18,57 +18,58 @@ namespace rnn_vad {
constexpr double kPi = 3.14159265358979323846;
constexpr size_t kSampleRate24kHz = 24000;
constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100;
constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
constexpr int kSampleRate24kHz = 24000;
constexpr int kFrameSize10ms24kHz = kSampleRate24kHz / 100;
constexpr int kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
// Pitch buffer.
constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
constexpr int kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
constexpr int kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
constexpr int kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
// 24 kHz analysis.
// Define a higher minimum pitch period for the initial search. This is used to
// avoid searching for very short periods, for which a refinement step is
// responsible.
constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// Number of (inverted) lags during the initial pitch search phase at 24 kHz.
constexpr int kInitialNumLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// Number of (inverted) lags during the pitch search refinement phase at 24 kHz.
constexpr int kRefineNumLags24kHz = kMaxPitch24kHz + 1;
static_assert(
kRefineNumLags24kHz > kInitialNumLags24kHz,
"The refinement step must search the pitch in an extended pitch range.");
// 12 kHz analysis.
constexpr size_t kSampleRate12kHz = 12000;
constexpr size_t kFrameSize10ms12kHz = kSampleRate12kHz / 100;
constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
constexpr size_t kBufSize12kHz = kBufSize24kHz / 2;
constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2;
constexpr int kSampleRate12kHz = 12000;
constexpr int kFrameSize10ms12kHz = kSampleRate12kHz / 100;
constexpr int kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
constexpr int kBufSize12kHz = kBufSize24kHz / 2;
constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2;
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
// The inverted lags for the pitch interval [`kInitialMinPitch12kHz`,
// `kMaxPitch12kHz`] are in the range [0, `kNumLags12kHz`].
constexpr int kNumLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
// 48 kHz constants.
constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
constexpr size_t kMaxPitch48kHz = kMaxPitch24kHz * 2;
constexpr int kMinPitch48kHz = kMinPitch24kHz * 2;
constexpr int kMaxPitch48kHz = kMaxPitch24kHz * 2;
// Spectral features.
constexpr size_t kNumBands = 22;
constexpr size_t kNumLowerBands = 6;
constexpr int kNumBands = 22;
constexpr int kNumLowerBands = 6;
static_assert((0 < kNumLowerBands) && (kNumLowerBands < kNumBands), "");
constexpr size_t kCepstralCoeffsHistorySize = 8;
constexpr int kCepstralCoeffsHistorySize = 8;
static_assert(kCepstralCoeffsHistorySize > 2,
"The history size must at least be 3 to compute first and second "
"derivatives.");
constexpr size_t kFeatureVectorSize = 42;
enum class Optimization { kNone, kSse2, kNeon };
// Detects what kind of optimizations to use for the code.
Optimization DetectOptimization();
constexpr int kFeatureVectorSize = 42;
} // namespace rnn_vad
} // namespace webrtc
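A quick numeric cross-check of the constants above (arithmetic only, not part of the header). With these values, a lag and its inverted lag always sum to the maximum pitch period:

// 24 kHz: 30 <= pitch period <= 384 samples, buffer of 384 + 480 samples.
static_assert(kMinPitch24kHz == 30, "24000 / 800");
static_assert(kMaxPitch24kHz == 384, "24000 / 62.5");
static_assert(kBufSize24kHz == 864, "kMaxPitch24kHz + kFrameSize20ms24kHz");
// 12 kHz initial search: 147 inverted lags (192 - 45).
static_assert(kNumLags12kHz == 147, "kMaxPitch12kHz - kInitialMinPitch12kHz");

// Lag <-> inverted lag conversion implied by the comments above.
constexpr int ToInvertedLag24kHz(int lag) {
  return kMaxPitch24kHz - lag;
}
static_assert(ToInvertedLag24kHz(kMaxPitch24kHz) == 0,
              "The maximum pitch period has inverted lag 0.");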

View File

@ -19,23 +19,23 @@ namespace webrtc {
namespace rnn_vad {
namespace {
// Generated via "B, A = scipy.signal.butter(2, 30/12000, btype='highpass')"
const BiQuadFilter::BiQuadCoefficients kHpfConfig24k = {
// Computed as `scipy.signal.butter(N=2, Wn=60/24000, btype='highpass')`.
constexpr BiQuadFilter::Config kHpfConfig24k{
{0.99446179f, -1.98892358f, 0.99446179f},
{-1.98889291f, 0.98895425f}};
} // namespace
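For context, a second-order section with the coefficients above can be applied as the direct-form I difference equation y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]. A hedged sketch of this 30 Hz high-pass at 24 kHz (the `BiQuadFilter` class in this tree may arrange its state differently):

struct BiQuadState {
  float x1 = 0.f, x2 = 0.f, y1 = 0.f, y2 = 0.f;
};

float HighPass30Hz24kHz(float x, BiQuadState& s) {
  // b = {0.99446179, -1.98892358, 0.99446179}, a = {-1.98889291, 0.98895425}.
  const float y = 0.99446179f * x - 1.98892358f * s.x1 + 0.99446179f * s.x2 +
                  1.98889291f * s.y1 - 0.98895425f * s.y2;
  s.x2 = s.x1;
  s.x1 = x;
  s.y2 = s.y1;
  s.y1 = y;
  return y;
}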
FeaturesExtractor::FeaturesExtractor()
FeaturesExtractor::FeaturesExtractor(const AvailableCpuFeatures& cpu_features)
: use_high_pass_filter_(false),
hpf_(kHpfConfig24k),
pitch_buf_24kHz_(),
pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()),
lp_residual_(kBufSize24kHz),
lp_residual_view_(lp_residual_.data(), kBufSize24kHz),
pitch_estimator_(),
pitch_estimator_(cpu_features),
reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) {
RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size());
hpf_.Initialize(kHpfConfig24k);
Reset();
}
@ -44,8 +44,9 @@ FeaturesExtractor::~FeaturesExtractor() = default;
void FeaturesExtractor::Reset() {
pitch_buf_24kHz_.Reset();
spectral_features_extractor_.Reset();
if (use_high_pass_filter_)
if (use_high_pass_filter_) {
hpf_.Reset();
}
}
bool FeaturesExtractor::CheckSilenceComputeFeatures(
@ -55,10 +56,10 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures(
if (use_high_pass_filter_) {
std::array<float, kFrameSize10ms24kHz> samples_filtered;
hpf_.Process(samples, samples_filtered);
// Feed buffer with the pre-processed version of |samples|.
// Feed buffer with the pre-processed version of `samples`.
pitch_buf_24kHz_.Push(samples_filtered);
} else {
// Feed buffer with |samples|.
// Feed buffer with `samples`.
pitch_buf_24kHz_.Push(samples);
}
// Extract the LP residual.
@ -67,13 +68,12 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures(
ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_);
// Estimate pitch on the LP-residual and write the normalized pitch period
// into the output vector (normalization based on training data stats).
pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
feature_vector[kFeatureVectorSize - 2] =
0.01f * (static_cast<int>(pitch_info_48kHz_.period) - 300);
pitch_period_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz_ - 300);
// Extract lagged frames (according to the estimated pitch period).
RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz);
RTC_DCHECK_LE(pitch_period_48kHz_ / 2, kMaxPitch24kHz);
auto lagged_frame = pitch_buf_24kHz_view_.subview(
kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz);
kMaxPitch24kHz - pitch_period_48kHz_ / 2, kFrameSize20ms24kHz);
// Analyze reference and lagged frames checking if silence has been detected
// and write the feature vector.
return spectral_features_extractor_.CheckSilenceComputeFeatures(
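An aside on the feature layout used above: the estimated 48 kHz pitch period is rescaled to a small, roughly zero-centered value before entering the RNN. A sketch of the mapping (the 300-sample center and 0.01 scale follow the expression in the code and, per the comment above, come from training-data statistics):

constexpr float NormalizePitchPeriod48kHz(int pitch_period_48kHz) {
  return 0.01f * (pitch_period_48kHz - 300);
}
static_assert(NormalizePitchPeriod48kHz(300) == 0.f, "Center maps to 0.");
// A 400-sample period maps to roughly +1, a 200-sample period to roughly -1.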

View File

@ -16,7 +16,6 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/biquad_filter.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
@ -27,14 +26,14 @@ namespace rnn_vad {
// Feature extractor to feed the VAD RNN.
class FeaturesExtractor {
public:
FeaturesExtractor();
explicit FeaturesExtractor(const AvailableCpuFeatures& cpu_features);
FeaturesExtractor(const FeaturesExtractor&) = delete;
FeaturesExtractor& operator=(const FeaturesExtractor&) = delete;
~FeaturesExtractor();
void Reset();
// Analyzes the samples, computes the feature vector and returns true if
// silence is detected (false if not). When silence is detected,
// |feature_vector| is partially written and therefore must not be used to
// `feature_vector` is partially written and therefore must not be used to
// feed the VAD RNN.
bool CheckSilenceComputeFeatures(
rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
@ -53,7 +52,7 @@ class FeaturesExtractor {
PitchEstimator pitch_estimator_;
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
SpectralFeaturesExtractor spectral_features_extractor_;
PitchInfo pitch_info_48kHz_;
int pitch_period_48kHz_;
};
} // namespace rnn_vad

View File

@ -16,27 +16,23 @@
#include <numeric>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// Computes cross-correlation coefficients between |x| and |y| and writes them
// in |x_corr|. The lag values are in {0, ..., max_lag - 1}, where max_lag
// equals the size of |x_corr|.
// The |x| and |y| sub-arrays used to compute a cross-correlation coefficients
// for a lag l have both size "size of |x| - l" - i.e., the longest sub-array is
// used. |x| and |y| must have the same size.
void ComputeCrossCorrelation(
// Computes auto-correlation coefficients for `x` and writes them in
// `auto_corr`. The lag values are in {0, ..., max_lag - 1}, where max_lag
// equals the size of `auto_corr`.
void ComputeAutoCorrelation(
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float, kNumLpcCoefficients> x_corr) {
constexpr size_t max_lag = x_corr.size();
RTC_DCHECK_EQ(x.size(), y.size());
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
constexpr int max_lag = auto_corr.size();
RTC_DCHECK_LT(max_lag, x.size());
for (size_t lag = 0; lag < max_lag; ++lag) {
x_corr[lag] =
std::inner_product(x.begin(), x.end() - lag, y.begin() + lag, 0.f);
for (int lag = 0; lag < max_lag; ++lag) {
auto_corr[lag] =
std::inner_product(x.begin(), x.end() - lag, x.begin() + lag, 0.f);
}
}
@ -45,9 +41,13 @@ void DenoiseAutoCorrelation(
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
// Assume -40 dB white noise floor.
auto_corr[0] *= 1.0001f;
for (size_t i = 1; i < kNumLpcCoefficients; ++i) {
auto_corr[i] -= auto_corr[i] * (0.008f * i) * (0.008f * i);
}
// Hard-coded values obtained as
// [np.float32((0.008*0.008*i*i)) for i in range(1,5)].
auto_corr[1] -= auto_corr[1] * 0.000064f;
auto_corr[2] -= auto_corr[2] * 0.000256f;
auto_corr[3] -= auto_corr[3] * 0.000576f;
auto_corr[4] -= auto_corr[4] * 0.001024f;
static_assert(kNumLpcCoefficients == 5, "Update `auto_corr`.");
}
// Computes the initial inverse filter coefficients given the auto-correlation
@ -56,9 +56,9 @@ void ComputeInitialInverseFilterCoefficients(
rtc::ArrayView<const float, kNumLpcCoefficients> auto_corr,
rtc::ArrayView<float, kNumLpcCoefficients - 1> lpc_coeffs) {
float error = auto_corr[0];
for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
for (int i = 0; i < kNumLpcCoefficients - 1; ++i) {
float reflection_coeff = 0.f;
for (size_t j = 0; j < i; ++j) {
for (int j = 0; j < i; ++j) {
reflection_coeff += lpc_coeffs[j] * auto_corr[i - j];
}
reflection_coeff += auto_corr[i + 1];
@ -72,7 +72,7 @@ void ComputeInitialInverseFilterCoefficients(
reflection_coeff /= -error;
// Update LPC coefficients and total error.
lpc_coeffs[i] = reflection_coeff;
for (size_t j = 0; j<(i + 1)>> 1; ++j) {
for (int j = 0; j < ((i + 1) >> 1); ++j) {
const float tmp1 = lpc_coeffs[j];
const float tmp2 = lpc_coeffs[i - 1 - j];
lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2;
@ -91,46 +91,49 @@ void ComputeAndPostProcessLpcCoefficients(
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs) {
std::array<float, kNumLpcCoefficients> auto_corr;
ComputeCrossCorrelation(x, x, {auto_corr.data(), auto_corr.size()});
ComputeAutoCorrelation(x, auto_corr);
if (auto_corr[0] == 0.f) { // Empty frame.
std::fill(lpc_coeffs.begin(), lpc_coeffs.end(), 0);
return;
}
DenoiseAutoCorrelation({auto_corr.data(), auto_corr.size()});
DenoiseAutoCorrelation(auto_corr);
std::array<float, kNumLpcCoefficients - 1> lpc_coeffs_pre{};
ComputeInitialInverseFilterCoefficients(auto_corr, lpc_coeffs_pre);
// LPC coefficients post-processing.
// TODO(bugs.webrtc.org/9076): Consider removing these steps.
float c1 = 1.f;
for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
c1 *= 0.9f;
lpc_coeffs_pre[i] *= c1;
}
const float c2 = 0.8f;
lpc_coeffs[0] = lpc_coeffs_pre[0] + c2;
lpc_coeffs[1] = lpc_coeffs_pre[1] + c2 * lpc_coeffs_pre[0];
lpc_coeffs[2] = lpc_coeffs_pre[2] + c2 * lpc_coeffs_pre[1];
lpc_coeffs[3] = lpc_coeffs_pre[3] + c2 * lpc_coeffs_pre[2];
lpc_coeffs[4] = c2 * lpc_coeffs_pre[3];
lpc_coeffs_pre[0] *= 0.9f;
lpc_coeffs_pre[1] *= 0.9f * 0.9f;
lpc_coeffs_pre[2] *= 0.9f * 0.9f * 0.9f;
lpc_coeffs_pre[3] *= 0.9f * 0.9f * 0.9f * 0.9f;
constexpr float kC = 0.8f;
lpc_coeffs[0] = lpc_coeffs_pre[0] + kC;
lpc_coeffs[1] = lpc_coeffs_pre[1] + kC * lpc_coeffs_pre[0];
lpc_coeffs[2] = lpc_coeffs_pre[2] + kC * lpc_coeffs_pre[1];
lpc_coeffs[3] = lpc_coeffs_pre[3] + kC * lpc_coeffs_pre[2];
lpc_coeffs[4] = kC * lpc_coeffs_pre[3];
static_assert(kNumLpcCoefficients == 5, "Update `lpc_coeffs(_pre)`.");
}
void ComputeLpResidual(
rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
rtc::ArrayView<const float> x,
rtc::ArrayView<float> y) {
RTC_DCHECK_LT(kNumLpcCoefficients, x.size());
RTC_DCHECK_GT(x.size(), kNumLpcCoefficients);
RTC_DCHECK_EQ(x.size(), y.size());
std::array<float, kNumLpcCoefficients> input_chunk;
input_chunk.fill(0.f);
for (size_t i = 0; i < y.size(); ++i) {
const float sum = std::inner_product(input_chunk.begin(), input_chunk.end(),
lpc_coeffs.begin(), x[i]);
// Circular shift and add a new sample.
for (size_t j = kNumLpcCoefficients - 1; j > 0; --j)
input_chunk[j] = input_chunk[j - 1];
input_chunk[0] = x[i];
// Copy result.
y[i] = sum;
// The code below implements the following operation:
// y[i] = x[i] + dot_product({x[i], ..., x[i - kNumLpcCoefficients + 1]},
// lpc_coeffs)
// Edge case: i < kNumLpcCoefficients.
y[0] = x[0];
for (int i = 1; i < kNumLpcCoefficients; ++i) {
y[i] =
std::inner_product(x.crend() - i, x.crend(), lpc_coeffs.cbegin(), x[i]);
}
// Regular case.
auto last = x.crend();
for (int i = kNumLpcCoefficients; rtc::SafeLt(i, y.size()); ++i, --last) {
y[i] = std::inner_product(last - kNumLpcCoefficients, last,
lpc_coeffs.cbegin(), x[i]);
}
}
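The two-branch loop above merely unrolls the i < kNumLpcCoefficients edge case. A naive single-loop reference that produces the same inverse-filter output (sketch, for readability only):

// y[i] = x[i] + sum over j < min(i, kNumLpcCoefficients) of
//        lpc_coeffs[j] * x[i - 1 - j].
void ComputeLpResidualReference(
    rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
    rtc::ArrayView<const float> x,
    rtc::ArrayView<float> y) {
  for (int i = 0; rtc::SafeLt(i, y.size()); ++i) {
    float sum = x[i];
    for (int j = 0; j < kNumLpcCoefficients && j < i; ++j) {
      sum += lpc_coeffs[j] * x[i - 1 - j];
    }
    y[i] = sum;
  }
}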

View File

@ -18,17 +18,17 @@
namespace webrtc {
namespace rnn_vad {
// LPC inverse filter length.
constexpr size_t kNumLpcCoefficients = 5;
// Linear predictive coding (LPC) inverse filter length.
constexpr int kNumLpcCoefficients = 5;
// Given a frame |x|, computes a post-processed version of LPC coefficients
// Given a frame `x`, computes a post-processed version of LPC coefficients
// tailored for pitch estimation.
void ComputeAndPostProcessLpcCoefficients(
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs);
// Computes the LP residual for the input frame |x| and the LPC coefficients
// |lpc_coeffs|. |y| and |x| can point to the same array for in-place
// Computes the LP residual for the input frame `x` and the LPC coefficients
// `lpc_coeffs`. `y` and `x` can point to the same array for in-place
// computation.
void ComputeLpResidual(
rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,

View File

@ -1,29 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
namespace webrtc {
namespace rnn_vad {
// Stores pitch period and gain information. The pitch gain measures the
// strength of the pitch (the higher, the stronger).
struct PitchInfo {
PitchInfo() : period(0), gain(0.f) {}
PitchInfo(int p, float g) : period(p), gain(g) {}
int period;
float gain;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_

View File

@ -18,38 +18,52 @@
namespace webrtc {
namespace rnn_vad {
PitchEstimator::PitchEstimator()
: pitch_buf_decimated_(kBufSize12kHz),
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
auto_corr_(kNumInvertedLags12kHz),
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size());
}
PitchEstimator::PitchEstimator(const AvailableCpuFeatures& cpu_features)
: cpu_features_(cpu_features),
y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
pitch_buffer_12kHz_(kBufSize12kHz),
auto_correlation_12kHz_(kNumLags12kHz) {}
PitchEstimator::~PitchEstimator() = default;
PitchInfo PitchEstimator::Estimate(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
int PitchEstimator::Estimate(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
rtc::ArrayView<float, kBufSize12kHz> pitch_buffer_12kHz_view(
pitch_buffer_12kHz_.data(), kBufSize12kHz);
RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), pitch_buffer_12kHz_view.size());
rtc::ArrayView<float, kNumLags12kHz> auto_correlation_12kHz_view(
auto_correlation_12kHz_.data(), kNumLags12kHz);
RTC_DCHECK_EQ(auto_correlation_12kHz_.size(),
auto_correlation_12kHz_view.size());
// TODO(bugs.chromium.org/10480): Use `cpu_features_` to estimate pitch.
// Perform the initial pitch search at 12 kHz.
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
auto_corr_view_);
std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
// Refine the pitch period estimation.
Decimate2x(pitch_buffer, pitch_buffer_12kHz_view);
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view,
auto_correlation_12kHz_view);
CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz(
pitch_buffer_12kHz_view, auto_correlation_12kHz_view, cpu_features_);
// The refinement is done using the pitch buffer that contains 24 kHz samples.
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
// Therefore, adapt the inverted lags in `pitch_periods` from 12
// to 24 kHz.
pitch_candidates_inv_lags[0] *= 2;
pitch_candidates_inv_lags[1] *= 2;
size_t pitch_inv_lag_48kHz =
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags);
// Look for stronger harmonics to find the final pitch period and its gain.
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
return last_pitch_48kHz_;
pitch_periods.best *= 2;
pitch_periods.second_best *= 2;
// Refine the initial pitch period estimation from 12 kHz to 48 kHz.
// Pre-compute frame energies at 24 kHz.
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_24kHz_view(
y_energy_24kHz_.data(), kRefineNumLags24kHz);
RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size());
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view,
cpu_features_);
// Estimation at 48 kHz.
const int pitch_lag_48kHz = ComputePitchPeriod48kHz(
pitch_buffer, y_energy_24kHz_view, pitch_periods, cpu_features_);
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
pitch_buffer, y_energy_24kHz_view,
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz,
last_pitch_48kHz_, cpu_features_);
return last_pitch_48kHz_.period;
}
} // namespace rnn_vad
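Putting the steps above together, the estimator is driven as follows (usage sketch; `GetAvailableCpuFeatures()` is assumed to be the factory exposed by the `cpu_features` target):

AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
PitchEstimator estimator(cpu_features);
std::array<float, kBufSize24kHz> pitch_buffer{};  // Filled by the caller.
// Internally: decimate to 12 kHz, coarse search on the auto-correlation,
// refine at 24 kHz with pre-computed frame energies, check sub-harmonics,
// and return the pitch period expressed at 48 kHz.
const int pitch_period_48kHz = estimator.Estimate(pitch_buffer);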

View File

@ -15,10 +15,11 @@
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
namespace rnn_vad {
@ -26,21 +27,25 @@ namespace rnn_vad {
// Pitch estimator.
class PitchEstimator {
public:
PitchEstimator();
explicit PitchEstimator(const AvailableCpuFeatures& cpu_features);
PitchEstimator(const PitchEstimator&) = delete;
PitchEstimator& operator=(const PitchEstimator&) = delete;
~PitchEstimator();
// Estimates the pitch period and gain. Returns the pitch estimation data for
// 48 kHz.
PitchInfo Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf);
// Returns the estimated pitch period at 48 kHz.
int Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer);
private:
PitchInfo last_pitch_48kHz_;
FRIEND_TEST_ALL_PREFIXES(RnnVadTest, PitchSearchWithinTolerance);
float GetLastPitchStrengthForTesting() const {
return last_pitch_48kHz_.strength;
}
const AvailableCpuFeatures cpu_features_;
PitchInfo last_pitch_48kHz_{};
AutoCorrelationCalculator auto_corr_calculator_;
std::vector<float> pitch_buf_decimated_;
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
std::vector<float> auto_corr_;
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr_view_;
std::vector<float> y_energy_24kHz_;
std::vector<float> pitch_buffer_12kHz_;
std::vector<float> auto_correlation_12kHz_;
};
} // namespace rnn_vad

View File

@ -18,103 +18,81 @@
#include <numeric>
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/system/arch.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// Converts a lag to an inverted lag (only for 24kHz).
size_t GetInvertedLag(size_t lag) {
RTC_DCHECK_LE(lag, kMaxPitch24kHz);
return kMaxPitch24kHz - lag;
float ComputeAutoCorrelation(
int inverted_lag,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
const VectorMath& vector_math) {
RTC_DCHECK_LT(inverted_lag, kBufSize24kHz);
RTC_DCHECK_LT(inverted_lag, kRefineNumLags24kHz);
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
return vector_math.DotProduct(
pitch_buffer.subview(/*offset=*/kMaxPitch24kHz),
pitch_buffer.subview(inverted_lag, kFrameSize20ms24kHz));
}
float ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf,
size_t inv_lag,
size_t max_pitch_period) {
RTC_DCHECK_LT(inv_lag, pitch_buf.size());
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
RTC_DCHECK_LE(inv_lag, max_pitch_period);
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
return std::inner_product(pitch_buf.begin() + max_pitch_period,
pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f);
}
// Computes a pseudo-interpolation offset for an estimated pitch period |lag| by
// looking at the auto-correlation coefficients in the neighborhood of |lag|.
// (namely, |prev_auto_corr|, |lag_auto_corr| and |next_auto_corr|). The output
// is a lag in {-1, 0, +1}.
// TODO(bugs.webrtc.org/9076): Consider removing pseudo-interpolation since it
// is relevant only if the spectral analysis works at a sample rate that is
// twice as that of the pitch buffer (not so important instead for the estimated
// pitch period feature fed into the RNN).
int GetPitchPseudoInterpolationOffset(size_t lag,
float prev_auto_corr,
float lag_auto_corr,
float next_auto_corr) {
const float& a = prev_auto_corr;
const float& b = lag_auto_corr;
const float& c = next_auto_corr;
int offset = 0;
if ((c - a) > 0.7f * (b - a)) {
offset = 1; // |c| is the largest auto-correlation coefficient.
} else if ((a - c) > 0.7f * (b - c)) {
offset = -1; // |a| is the largest auto-correlation coefficient.
// Given an auto-correlation coefficient `curr_auto_correlation` and its
// neighboring values `prev_auto_correlation` and `next_auto_correlation`
// computes a pseudo-interpolation offset to be applied to the pitch period
// associated to `curr`. The output is a lag in {-1, 0, +1}.
// TODO(bugs.webrtc.org/9076): Consider removing this method.
// `GetPitchPseudoInterpolationOffset()` is relevant only if the spectral
// analysis works at a sample rate that is twice that of the pitch buffer; in
// particular, it is not relevant for the estimated pitch period feature fed
// into the RNN.
int GetPitchPseudoInterpolationOffset(float prev_auto_correlation,
float curr_auto_correlation,
float next_auto_correlation) {
if ((next_auto_correlation - prev_auto_correlation) >
0.7f * (curr_auto_correlation - prev_auto_correlation)) {
return 1; // `next_auto_correlation` is the largest auto-correlation
// coefficient.
} else if ((prev_auto_correlation - next_auto_correlation) >
0.7f * (curr_auto_correlation - next_auto_correlation)) {
return -1; // `prev_auto_correlation` is the largest auto-correlation
// coefficient.
}
return offset;
return 0;
}
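A few worked values for the rule above (sketch): the offset moves toward a neighbor only when that neighbor's coefficient dominates by the 0.7 margin.

// (prev, curr, next) = (0.2, 1.0, 0.9):
//   next - prev = 0.7 > 0.7 * (curr - prev) = 0.56  ->  offset = +1.
// (prev, curr, next) = (0.9, 1.0, 0.2):
//   prev - next = 0.7 > 0.7 * (curr - next) = 0.56  ->  offset = -1.
// A symmetric neighborhood (prev == next) always yields offset = 0.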
// Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The
// output sample rate is twice as that of |lag|.
size_t PitchPseudoInterpolationLagPitchBuf(
size_t lag,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
// Refines a pitch period `lag` encoded as lag with pseudo-interpolation. The
// output sample rate is twice that of `lag`.
int PitchPseudoInterpolationLagPitchBuf(
int lag,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
const VectorMath& vector_math) {
int offset = 0;
// Cannot apply pseudo-interpolation at the boundaries.
if (lag > 0 && lag < kMaxPitch24kHz) {
const int inverted_lag = kMaxPitch24kHz - lag;
offset = GetPitchPseudoInterpolationOffset(
lag,
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1),
kMaxPitch24kHz),
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag),
kMaxPitch24kHz),
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1),
kMaxPitch24kHz));
ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer, vector_math),
ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math),
ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer, vector_math));
}
return 2 * lag + offset;
}
// Refines a pitch period |inv_lag| encoded as inverted lag with
// pseudo-interpolation. The output sample rate is twice as that of
// |inv_lag|.
size_t PitchPseudoInterpolationInvLagAutoCorr(
size_t inv_lag,
rtc::ArrayView<const float> auto_corr) {
int offset = 0;
// Cannot apply pseudo-interpolation at the boundaries.
if (inv_lag > 0 && inv_lag < auto_corr.size() - 1) {
offset = GetPitchPseudoInterpolationOffset(inv_lag, auto_corr[inv_lag + 1],
auto_corr[inv_lag],
auto_corr[inv_lag - 1]);
}
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
// be subtracted since |inv_lag| is an inverted lag but offset is a lag.
return 2 * inv_lag + offset;
}
// Integer multipliers used in CheckLowerPitchPeriodsAndComputePitchGain() when
// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when
// looking for sub-harmonics.
// The values have been chosen to serve the following algorithm. Given the
// initial pitch period T, we examine whether one of its harmonics is the true
// fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of
// these harmonics, in addition to the pitch gain of itself, we choose one
// these harmonics, in addition to the pitch strength of itself, we choose one
// multiple of its pitch period, n*T/k, to validate it (by averaging their pitch
// gains). The multiplier n is chosen so that n*T/k is used only one time over
// all k. When for example k = 4, we should also expect a peak at 3*T/4. When
// k = 8 instead we don't want to look at 2*T/8, since we have already checked
// T/4 before. Instead, we look at T*3/8.
// strengths). The multiplier n is chosen so that n*T/k is used only one time
// over all k. When for example k = 4, we should also expect a peak at 3*T/4.
// When k = 8 instead we don't want to look at 2*T/8, since we have already
// checked T/4 before. Instead, we look at T*3/8.
// The array can be generated in Python as follows:
// from fractions import Fraction
// # Smallest positive integer not in X.
@ -131,96 +109,220 @@ size_t PitchPseudoInterpolationInvLagAutoCorr(
constexpr std::array<int, 14> kSubHarmonicMultipliers = {
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
// a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)].
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
struct Range {
int min;
int max;
};
// Number of analyzed pitches to the left (right) of a pitch candidate.
constexpr int kPitchNeighborhoodRadius = 2;
// Creates a pitch period interval centered in `inverted_lag` with hard-coded
// radius. Clipping is applied so that the interval is always valid for a 24 kHz
// pitch buffer.
Range CreateInvertedLagRange(int inverted_lag) {
return {std::max(inverted_lag - kPitchNeighborhoodRadius, 0),
std::min(inverted_lag + kPitchNeighborhoodRadius,
kInitialNumLags24kHz - 1)};
}
constexpr int kNumPitchCandidates = 2; // Best and second best.
// Maximum number of analyzed pitch periods.
constexpr int kMaxPitchPeriods24kHz =
kNumPitchCandidates * (2 * kPitchNeighborhoodRadius + 1);
// Collection of inverted lags.
class InvertedLagsIndex {
public:
InvertedLagsIndex() : num_entries_(0) {}
// Adds an inverted lag to the index. Cannot add more than
// `kMaxPitchPeriods24kHz` values.
void Append(int inverted_lag) {
RTC_DCHECK_LT(num_entries_, kMaxPitchPeriods24kHz);
inverted_lags_[num_entries_++] = inverted_lag;
}
const int* data() const { return inverted_lags_.data(); }
int size() const { return num_entries_; }
private:
std::array<int, kMaxPitchPeriods24kHz> inverted_lags_;
int num_entries_;
};
// Computes the auto correlation coefficients for the inverted lags in the
// closed interval `inverted_lags`. Updates `inverted_lags_index` by appending
// the inverted lags for the computed auto correlation values.
void ComputeAutoCorrelation(
Range inverted_lags,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation,
InvertedLagsIndex& inverted_lags_index,
const VectorMath& vector_math) {
// Check valid range.
RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max);
// Trick to avoid zero-initializing the whole `auto_correlation` array: only
// the two entries adjacent to the computed range, which pseudo-interpolation
// may read, are cleared.
if (inverted_lags.min > 0) {
auto_correlation[inverted_lags.min - 1] = 0.f;
}
if (inverted_lags.max < kInitialNumLags24kHz - 1) {
auto_correlation[inverted_lags.max + 1] = 0.f;
}
// Check valid `inverted_lag` indexes.
RTC_DCHECK_GE(inverted_lags.min, 0);
RTC_DCHECK_LT(inverted_lags.max, kInitialNumLags24kHz);
for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max;
++inverted_lag) {
auto_correlation[inverted_lag] =
ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math);
inverted_lags_index.Append(inverted_lag);
}
}
// Searches the strongest pitch period at 24 kHz and returns its inverted lag at
// 48 kHz.
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const int> inverted_lags,
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
const VectorMath& vector_math) {
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
int best_inverted_lag = 0; // Pitch period.
float best_numerator = -1.f; // Pitch strength numerator.
float best_denominator = 0.f; // Pitch strength denominator.
for (int inverted_lag : inverted_lags) {
// A pitch candidate must have positive correlation.
if (auto_correlation[inverted_lag] > 0.f) {
// Auto-correlation energy normalized by frame energy.
const float numerator =
auto_correlation[inverted_lag] * auto_correlation[inverted_lag];
const float denominator = y_energy[inverted_lag];
// Compare numerator/denominator ratios without using divisions.
if (numerator * best_denominator > best_numerator * denominator) {
best_inverted_lag = inverted_lag;
best_numerator = numerator;
best_denominator = denominator;
}
}
}
// Pseudo-interpolation to transform `best_inverted_lag` (24 kHz pitch) to a
// 48 kHz pitch period.
if (best_inverted_lag == 0 || best_inverted_lag >= kInitialNumLags24kHz - 1) {
// Cannot apply pseudo-interpolation at the boundaries.
return best_inverted_lag * 2;
}
int offset = GetPitchPseudoInterpolationOffset(
auto_correlation[best_inverted_lag + 1],
auto_correlation[best_inverted_lag],
auto_correlation[best_inverted_lag - 1]);
// TODO(bugs.webrtc.org/9076): When retraining, check if `offset` below should
// be subtracted since `inverted_lag` is an inverted lag but offset is a lag.
return 2 * best_inverted_lag + offset;
}
// Returns an alternative pitch period for `pitch_period` given a `multiplier`
// and a `divisor` of the period.
constexpr int GetAlternativePitchPeriod(int pitch_period,
int multiplier,
int divisor) {
RTC_DCHECK_GT(divisor, 0);
// Same as `round(multiplier * pitch_period / divisor)`.
return (2 * multiplier * pitch_period + divisor) / (2 * divisor);
}
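The `(2 * multiplier * pitch_period + divisor) / (2 * divisor)` expression implements round-to-nearest in pure integer arithmetic. A self-contained re-derivation (sketch):

constexpr int RoundedRatio(int numerator, int divisor) {
  return (2 * numerator + divisor) / (2 * divisor);
}
static_assert(RoundedRatio(300, 7) == 43, "round(42.86) == 43");
static_assert(RoundedRatio(3 * 100, 8) == 38, "round(37.5) rounds up to 38");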
// Returns true if the alternative pitch period is stronger than the initial one
// given the last estimated pitch and the value of `period_divisor` used to
// compute the alternative pitch period via `GetAlternativePitchPeriod()`.
bool IsAlternativePitchStrongerThanInitial(PitchInfo last,
PitchInfo initial,
PitchInfo alternative,
int period_divisor) {
// Initial pitch period candidate thresholds for a sample rate of 24 kHz.
// Computed as [5*k*k for k in range(2, 16)].
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
static_assert(
kInitialPitchPeriodThresholds.size() == kSubHarmonicMultipliers.size(),
"");
RTC_DCHECK_GE(last.period, 0);
RTC_DCHECK_GE(initial.period, 0);
RTC_DCHECK_GE(alternative.period, 0);
RTC_DCHECK_GE(period_divisor, 2);
// Compute a term that lowers the threshold when `alternative.period` is close
// to the last estimated period `last.period` - i.e., pitch tracking.
float lower_threshold_term = 0.f;
if (std::abs(alternative.period - last.period) <= 1) {
// The candidate pitch period is within 1 sample from the last one.
// Make the candidate at `alternative.period` very easy to be accepted.
lower_threshold_term = last.strength;
} else if (std::abs(alternative.period - last.period) == 2 &&
initial.period >
kInitialPitchPeriodThresholds[period_divisor - 2]) {
// The candidate pitch period is 2 samples far from the last one and the
// period `initial.period` (from which `alternative.period` has been
// derived) is greater than a threshold. Make `alternative.period` easy to
// be accepted.
lower_threshold_term = 0.5f * last.strength;
}
// Set the threshold based on the strength of the initial estimate
// `initial.period`. Also reduce the chance of false positives caused by a
// bias towards high frequencies (originating from short-term correlations).
float threshold =
std::max(0.3f, 0.7f * initial.strength - lower_threshold_term);
if (alternative.period < 3 * kMinPitch24kHz) {
// High frequency.
threshold = std::max(0.4f, 0.85f * initial.strength - lower_threshold_term);
} else if (alternative.period < 2 * kMinPitch24kHz) {
// Even higher frequency.
threshold = std::max(0.5f, 0.9f * initial.strength - lower_threshold_term);
}
return alternative.strength > threshold;
}
} // namespace
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
rtc::ArrayView<float, kBufSize12kHz> dst) {
// TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter.
static_assert(2 * dst.size() == src.size(), "");
for (size_t i = 0; i < dst.size(); ++i) {
static_assert(2 * kBufSize12kHz == kBufSize24kHz, "");
for (int i = 0; i < kBufSize12kHz; ++i) {
dst[i] = src[2 * i];
}
}
float ComputePitchGainThreshold(int candidate_pitch_period,
int pitch_period_ratio,
int initial_pitch_period,
float initial_pitch_gain,
int prev_pitch_period,
float prev_pitch_gain) {
// Map arguments to more compact aliases.
const int& t1 = candidate_pitch_period;
const int& k = pitch_period_ratio;
const int& t0 = initial_pitch_period;
const float& g0 = initial_pitch_gain;
const int& t_prev = prev_pitch_period;
const float& g_prev = prev_pitch_gain;
// Validate input.
RTC_DCHECK_GE(t1, 0);
RTC_DCHECK_GE(k, 2);
RTC_DCHECK_GE(t0, 0);
RTC_DCHECK_GE(t_prev, 0);
// Compute a term that lowers the threshold when |t1| is close to the last
// estimated period |t_prev| - i.e., pitch tracking.
float lower_threshold_term = 0;
if (abs(t1 - t_prev) <= 1) {
// The candidate pitch period is within 1 sample from the previous one.
// Make the candidate at |t1| very easy to be accepted.
lower_threshold_term = g_prev;
} else if (abs(t1 - t_prev) == 2 &&
t0 > kInitialPitchPeriodThresholds[k - 2]) {
// The candidate pitch period is 2 samples far from the previous one and the
// period |t0| (from which |t1| has been derived) is greater than a
// threshold. Make |t1| easy to be accepted.
lower_threshold_term = 0.5f * g_prev;
}
// Set the threshold based on the gain of the initial estimate |t0|. Also
// reduce the chance of false positives caused by a bias towards high
// frequencies (originating from short-term correlations).
float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
// High frequency.
threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
} else if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
// Even higher frequency.
threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
}
return threshold;
}
void ComputeSlidingFrameSquareEnergies(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values) {
float yy =
ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz);
yy_values[0] = yy;
for (size_t i = 1; i < yy_values.size(); ++i) {
RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz);
RTC_DCHECK_LE(i, kMaxPitch24kHz);
const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i];
const float new_coeff = pitch_buf[kMaxPitch24kHz - i];
yy -= old_coeff * old_coeff;
yy += new_coeff * new_coeff;
yy = std::max(0.f, yy);
yy_values[i] = yy;
void ComputeSlidingFrameSquareEnergies24kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy,
AvailableCpuFeatures cpu_features) {
VectorMath vector_math(cpu_features);
static_assert(kFrameSize20ms24kHz < kBufSize24kHz, "");
const auto frame_20ms_view = pitch_buffer.subview(0, kFrameSize20ms24kHz);
float yy = vector_math.DotProduct(frame_20ms_view, frame_20ms_view);
y_energy[0] = yy;
static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, "");
static_assert(kMaxPitch24kHz < kRefineNumLags24kHz, "");
for (int inverted_lag = 0; inverted_lag < kMaxPitch24kHz; ++inverted_lag) {
yy -= pitch_buffer[inverted_lag] * pitch_buffer[inverted_lag];
yy += pitch_buffer[inverted_lag + kFrameSize20ms24kHz] *
pitch_buffer[inverted_lag + kFrameSize20ms24kHz];
yy = std::max(1.f, yy);
y_energy[inverted_lag + 1] = yy;
}
}
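Each `y_energy[inverted_lag + 1]` above reuses the previous value: the sample leaving the 20 ms window is subtracted and the sample entering it is added, turning an O(N*M) computation into O(N). A brute-force equivalent for a single entry (sketch; note that the production loop additionally clamps with `std::max(1.f, yy)` so that later divisions stay safe):

float FrameEnergy24kHz(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
    int inverted_lag) {
  float energy = 0.f;
  for (int i = 0; i < kFrameSize20ms24kHz; ++i) {
    const float sample = pitch_buffer[inverted_lag + i];
    energy += sample * sample;
  }
  return energy;
}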
std::array<size_t, 2> FindBestPitchPeriods(
rtc::ArrayView<const float> auto_corr,
rtc::ArrayView<const float> pitch_buf,
size_t max_pitch_period) {
CandidatePitchPeriods ComputePitchPeriod12kHz(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation,
AvailableCpuFeatures cpu_features) {
static_assert(kMaxPitch12kHz > kNumLags12kHz, "");
static_assert(kMaxPitch12kHz < kBufSize12kHz, "");
// Stores a pitch candidate period and strength information.
struct PitchCandidate {
// Pitch period encoded as inverted lag.
size_t period_inverted_lag = 0;
int period_inverted_lag = 0;
// Pitch strength encoded as a ratio.
float strength_numerator = -1.f;
float strength_denominator = 0.f;
@ -232,25 +334,22 @@ std::array<size_t, 2> FindBestPitchPeriods(
}
};
RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
const size_t frame_size = pitch_buf.size() - max_pitch_period;
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
float yy =
std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1,
pitch_buf.begin(), 1.f);
VectorMath vector_math(cpu_features);
static_assert(kFrameSize20ms12kHz + 1 < kBufSize12kHz, "");
const auto frame_view = pitch_buffer.subview(0, kFrameSize20ms12kHz + 1);
float denominator = 1.f + vector_math.DotProduct(frame_view, frame_view);
// Search best and second best pitches by looking at the scaled
// auto-correlation.
PitchCandidate candidate;
PitchCandidate best;
PitchCandidate second_best;
second_best.period_inverted_lag = 1;
for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
for (int inverted_lag = 0; inverted_lag < kNumLags12kHz; ++inverted_lag) {
// A pitch candidate must have positive correlation.
if (auto_corr[inv_lag] > 0) {
candidate.period_inverted_lag = inv_lag;
candidate.strength_numerator = auto_corr[inv_lag] * auto_corr[inv_lag];
candidate.strength_denominator = yy;
if (auto_correlation[inverted_lag] > 0.f) {
PitchCandidate candidate{
inverted_lag,
auto_correlation[inverted_lag] * auto_correlation[inverted_lag],
denominator};
if (candidate.HasStrongerPitchThan(second_best)) {
if (candidate.HasStrongerPitchThan(best)) {
second_best = best;
@ -260,143 +359,154 @@ std::array<size_t, 2> FindBestPitchPeriods(
}
}
}
// Update |squared_energy_y| for the next inverted lag.
const float old_coeff = pitch_buf[inv_lag];
const float new_coeff = pitch_buf[inv_lag + frame_size];
yy -= old_coeff * old_coeff;
yy += new_coeff * new_coeff;
yy = std::max(0.f, yy);
// Update `denominator` for the next inverted lag.
const float y_old = pitch_buffer[inverted_lag];
const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms12kHz];
denominator -= y_old * y_old;
denominator += y_new * y_new;
denominator = std::max(0.f, denominator);
}
return {{best.period_inverted_lag, second_best.period_inverted_lag}};
return {best.period_inverted_lag, second_best.period_inverted_lag};
}
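`HasStrongerPitchThan()` above ranks candidates by the ratio `strength_numerator / strength_denominator` without dividing: with non-negative denominators, n1/d1 > n2/d2 exactly when n1*d2 > n2*d1. A tiny sketch of that comparison:

constexpr bool StrongerThan(float numerator, float denominator,
                            float other_numerator, float other_denominator) {
  // Cross-multiplication is valid here since both denominators are >= 0.
  return numerator * other_denominator > other_numerator * denominator;
}
static_assert(StrongerThan(9.f, 4.f, 4.f, 2.f), "2.25 > 2.0");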
size_t RefinePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
rtc::ArrayView<const size_t, 2> inv_lags) {
// Compute the auto-correlation terms only for neighbors of the given pitch
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
// for a few lag values).
std::array<float, kNumInvertedLags24kHz> auto_corr;
auto_corr.fill(0.f); // Zeros become ignored lags in FindBestPitchPeriods().
auto is_neighbor = [](size_t i, size_t j) {
return ((i > j) ? (i - j) : (j - i)) <= 2;
};
for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
if (is_neighbor(inv_lag, inv_lags[0]) || is_neighbor(inv_lag, inv_lags[1]))
auto_corr[inv_lag] =
ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, kMaxPitch24kHz);
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
CandidatePitchPeriods pitch_candidates,
AvailableCpuFeatures cpu_features) {
// Compute the auto-correlation terms only for neighbors of the two pitch
// candidates (best and second best).
std::array<float, kInitialNumLags24kHz> auto_correlation;
InvertedLagsIndex inverted_lags_index;
// Create two inverted lag ranges so that `r1` precedes `r2`.
const bool swap_candidates =
pitch_candidates.best > pitch_candidates.second_best;
const Range r1 = CreateInvertedLagRange(
swap_candidates ? pitch_candidates.second_best : pitch_candidates.best);
const Range r2 = CreateInvertedLagRange(
swap_candidates ? pitch_candidates.best : pitch_candidates.second_best);
// Check valid ranges.
RTC_DCHECK_LE(r1.min, r1.max);
RTC_DCHECK_LE(r2.min, r2.max);
// Check `r1` precedes `r2`.
RTC_DCHECK_LE(r1.min, r2.min);
RTC_DCHECK_LE(r1.max, r2.max);
VectorMath vector_math(cpu_features);
if (r1.max + 1 >= r2.min) {
// Overlapping or adjacent ranges.
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation,
inverted_lags_index, vector_math);
} else {
// Disjoint ranges.
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation,
inverted_lags_index, vector_math);
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation,
inverted_lags_index, vector_math);
}
// Find best pitch at 24 kHz.
const auto pitch_candidates_inv_lags = FindBestPitchPeriods(
{auto_corr.data(), auto_corr.size()},
{pitch_buf.data(), pitch_buf.size()}, kMaxPitch24kHz);
const auto inv_lag = pitch_candidates_inv_lags[0]; // Refine the best.
// Pseudo-interpolation.
return PitchPseudoInterpolationInvLagAutoCorr(inv_lag, auto_corr);
return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index,
auto_correlation, y_energy, vector_math);
}
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
PitchInfo ComputeExtendedPitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
int initial_pitch_period_48kHz,
PitchInfo prev_pitch_48kHz) {
PitchInfo last_pitch_48kHz,
AvailableCpuFeatures cpu_features) {
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
// Stores information for a refined pitch candidate.
struct RefinedPitchCandidate {
RefinedPitchCandidate() {}
RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy)
: period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
int period_24kHz;
// Pitch strength information.
float gain;
// Additional pitch strength information used for the final estimation of
// pitch gain.
float xy; // Cross-correlation.
float yy; // Auto-correlation.
int period;
float strength;
// Additional strength data used for the final pitch estimation.
float xy; // Auto-correlation.
float y_energy; // Energy of the sliding frame `y`.
};
// Initialize.
std::array<float, kMaxPitch24kHz + 1> yy_values;
ComputeSlidingFrameSquareEnergies(pitch_buf,
{yy_values.data(), yy_values.size()});
const float xx = yy_values[0];
// Helper lambdas.
const auto pitch_gain = [](float xy, float yy, float xx) {
RTC_DCHECK_LE(0.f, xx * yy);
return xy / std::sqrt(1.f + xx * yy);
const float x_energy = y_energy[kMaxPitch24kHz];
const auto pitch_strength = [x_energy](float xy, float y_energy) {
RTC_DCHECK_GE(x_energy * y_energy, 0.f);
return xy / std::sqrt(1.f + x_energy * y_energy);
};
// Initial pitch candidate gain.
VectorMath vector_math(cpu_features);
// Initialize the best pitch candidate with `initial_pitch_period_48kHz`.
RefinedPitchCandidate best_pitch;
best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2,
static_cast<int>(kMaxPitch24kHz - 1));
best_pitch.xy = ComputeAutoCorrelationCoeff(
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
best_pitch.yy = yy_values[best_pitch.period_24kHz];
best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx);
best_pitch.period =
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
best_pitch.xy = ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period,
pitch_buffer, vector_math);
best_pitch.y_energy = y_energy[kMaxPitch24kHz - best_pitch.period];
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy);
// Keep a copy of the initial pitch candidate.
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
// 24 kHz version of the last estimated pitch.
  const PitchInfo last_pitch{last_pitch_48kHz.period / 2,
                             last_pitch_48kHz.strength};
  // Keep a copy of the initial pitch candidate.
  const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
  // Given the initial pitch estimation, check lower periods (i.e., harmonics).
  // Find `max_period_divisor` such that the result of
  // `GetAlternativePitchPeriod(initial_pitch.period, 1, max_period_divisor)`
  // equals `kMinPitch24kHz`.
  const int max_period_divisor =
      (2 * initial_pitch.period) / (2 * kMinPitch24kHz - 1);
  for (int period_divisor = 2; period_divisor <= max_period_divisor;
       ++period_divisor) {
    PitchInfo alternative_pitch;
    alternative_pitch.period = GetAlternativePitchPeriod(
        initial_pitch.period, /*multiplier=*/1, period_divisor);
    RTC_DCHECK_GE(alternative_pitch.period, kMinPitch24kHz);
    // When looking at `alternative_pitch.period`, we also look at one of its
    // sub-harmonics. `kSubHarmonicMultipliers` is used to know where to look.
    // `period_divisor` == 2 is a special case since `dual_alternative_period`
    // might be greater than the maximum pitch period.
    int dual_alternative_period = GetAlternativePitchPeriod(
        initial_pitch.period, kSubHarmonicMultipliers[period_divisor - 2],
        period_divisor);
    RTC_DCHECK_GT(dual_alternative_period, 0);
    if (period_divisor == 2 && dual_alternative_period > kMaxPitch24kHz) {
      dual_alternative_period = initial_pitch.period;
    }
    RTC_DCHECK_NE(alternative_pitch.period, dual_alternative_period)
        << "The lower pitch period and the additional sub-harmonic must not "
           "coincide.";
    // Compute an auto-correlation score for the candidate pitch period
    // `alternative_pitch.period` by also looking at its possible sub-harmonic
    // `dual_alternative_period`.
    const float xy_primary_period = ComputeAutoCorrelation(
        kMaxPitch24kHz - alternative_pitch.period, pitch_buffer, vector_math);
    // TODO(webrtc:10480): Copy `xy_primary_period` if the secondary period is
    // equal to the primary one.
    const float xy_secondary_period = ComputeAutoCorrelation(
        kMaxPitch24kHz - dual_alternative_period, pitch_buffer, vector_math);
    const float xy = 0.5f * (xy_primary_period + xy_secondary_period);
    const float yy =
        0.5f * (y_energy[kMaxPitch24kHz - alternative_pitch.period] +
                y_energy[kMaxPitch24kHz - dual_alternative_period]);
    alternative_pitch.strength = pitch_strength(xy, yy);
    // Maybe update best period.
    if (IsAlternativePitchStrongerThanInitial(
            last_pitch, initial_pitch, alternative_pitch, period_divisor)) {
      best_pitch = {alternative_pitch.period, alternative_pitch.strength, xy,
                    yy};
    }
  }
  // Final pitch strength and period.
  best_pitch.xy = std::max(0.f, best_pitch.xy);
  RTC_DCHECK_LE(0.f, best_pitch.y_energy);
  float final_pitch_strength =
      (best_pitch.y_energy <= best_pitch.xy)
          ? 1.f
          : best_pitch.xy / (best_pitch.y_energy + 1.f);
  final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength);
  int final_pitch_period_48kHz = std::max(
      kMinPitch48kHz, PitchPseudoInterpolationLagPitchBuf(
                          best_pitch.period, pitch_buffer, vector_math));
  return {final_pitch_period_48kHz, final_pitch_strength};
}
} // namespace rnn_vad
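A worked check of the `max_period_divisor` bound used above — a sketch only, assuming that `GetAlternativePitchPeriod(p, 1, k)` computes `(2p + k) / (2k)`, i.e. integer `round(p / k)`. Requiring the alternative period to stay at or above `kMinPitch24kHz` (call it $p_{\min}$) gives

$$\frac{2p + k}{2k} \ge p_{\min} \iff 2p + k \ge 2k\,p_{\min} \iff k \le \frac{2p}{2\,p_{\min} - 1},$$

so the largest admissible `period_divisor` is the integer quotient `(2 * initial_pitch.period) / (2 * kMinPitch24kHz - 1)` computed before the loop.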

View File

@ -14,10 +14,11 @@
#include <stddef.h>
#include <array>
#include <utility>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
namespace webrtc {
namespace rnn_vad {
@ -26,50 +27,86 @@ namespace rnn_vad {
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
rtc::ArrayView<float, kBufSize12kHz> dst);
// Key concepts and keywords used below in this file.
//
// The pitch estimation relies on a pitch buffer, which is an array-like data
// structure designed as follows:
//
// |....A....|.....B.....|
//
// The part on the left, named `A`, contains the oldest samples, whereas `B`
// contains the most recent ones. The size of `A` corresponds to the maximum
// pitch period, that of `B` to the analysis frame size (e.g., 16 ms and 20 ms
// respectively).
//
// Pitch estimation is essentially based on the analysis of two 20 ms frames
// extracted from the pitch buffer. One frame, called `x`, is kept fixed and
// corresponds to `B` - i.e., the most recent 20 ms. The other frame, called
// `y`, is extracted from different parts of the buffer instead.
//
// The offset between `x` and `y` corresponds to a specific pitch period.
// For instance, if `y` is positioned at the beginning of the pitch buffer, then
// the cross-correlation between `x` and `y` can be used as an indication of the
// strength for the maximum pitch.
//
// Such an offset can be encoded in two ways:
// - As a lag, which is the index in the pitch buffer for the first item in `y`
// - As an inverted lag, which is the number of samples between the beginning
//   of `x` and the end of `y`
//
// |---->| lag
// |....A....|.....B.....|
// |<--| inverted lag
// |.....y.....| `y` 20 ms frame
//
// The inverted lag has the advantage of being directly proportional to the
// corresponding pitch period.
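A minimal sketch of the two encodings, assuming the 16 ms / 20 ms sizes quoted above (384 and 480 samples at 24 kHz); the names here are hypothetical illustrations, not part of the API:

// `lag` indexes the first sample of `y`; the inverted lag counts the samples
// between the beginning of `x` (index kMaxPitchPeriod) and the end of `y`.
// The two encodings always sum to kMaxPitchPeriod, which is why the inverted
// lag is directly proportional to the pitch period in samples.
constexpr int kMaxPitchPeriod = 384;  // Size of `A` (16 ms at 24 kHz).
constexpr int InvertedLag(int lag) {
  return kMaxPitchPeriod - lag;
}
static_assert(InvertedLag(0) == kMaxPitchPeriod, "Largest period at lag 0.");
static_assert(InvertedLag(kMaxPitchPeriod) == 0, "Zero period at maximum lag.");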
// Computes the sum of squared samples for every sliding frame `y` in the pitch
// buffer. The indexes of `y_energy` are inverted lags.
void ComputeSlidingFrameSquareEnergies24kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy,
AvailableCpuFeatures cpu_features);
// Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags.
struct CandidatePitchPeriods {
int best;
int second_best;
};
// Computes the candidate pitch periods at 12 kHz given a view on the 12 kHz
// pitch buffer and the auto-correlation values (having inverted lags as
// indexes).
CandidatePitchPeriods ComputePitchPeriod12kHz(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation,
AvailableCpuFeatures cpu_features);
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer,
// the energies for the sliding frames `y` at 24 kHz and the pitch period
// candidates at 24 kHz (encoded as inverted lag).
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
CandidatePitchPeriods pitch_candidates_24kHz,
AvailableCpuFeatures cpu_features);
struct PitchInfo {
int period;
float strength;
};
// Computes the pitch period at 48 kHz searching in an extended pitch range
// given a view on the 24 kHz pitch buffer, the energies for the sliding frames
// `y` at 24 kHz, the initial 48 kHz estimation (computed by
// `ComputePitchPeriod48kHz()`) and the last estimated pitch.
PitchInfo ComputeExtendedPitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
int initial_pitch_period_48kHz,
PitchInfo prev_pitch_48kHz);
PitchInfo last_pitch_48kHz,
AvailableCpuFeatures cpu_features);
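Taken together, these declarations suggest the following call sequence; this is a sketch with assumed glue code (in particular, doubling the 12 kHz inverted lags to obtain 24 kHz candidates is an inference from the comments above, not shown in this header):

const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
// 1) Coarse candidate search on the decimated 12 kHz pitch buffer.
const CandidatePitchPeriods candidates_12kHz = ComputePitchPeriod12kHz(
    pitch_buffer_12kHz, auto_correlation, cpu_features);
// Adapt the inverted lags from 12 kHz to 24 kHz (assumed factor of 2).
const CandidatePitchPeriods candidates_24kHz{candidates_12kHz.best * 2,
                                             candidates_12kHz.second_best * 2};
// 2) Refinement at 24 kHz; the result is encoded as a 48 kHz period.
const int initial_pitch_period_48kHz = ComputePitchPeriod48kHz(
    pitch_buffer_24kHz, y_energy, candidates_24kHz, cpu_features);
// 3) Extended search around the initial estimate, using the last pitch.
const PitchInfo pitch = ComputeExtendedPitchPeriod48kHz(
    pitch_buffer_24kHz, y_energy, initial_pitch_period_48kHz, last_pitch,
    cpu_features);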
} // namespace rnn_vad
} // namespace webrtc

View File

@ -21,7 +21,7 @@ namespace webrtc {
namespace rnn_vad {
// Ring buffer for N arrays of type T each one with size S.
template <typename T, int S, int N>
class RingBuffer {
static_assert(S > 0, "");
static_assert(N > 0, "");
@ -35,7 +35,7 @@ class RingBuffer {
~RingBuffer() = default;
// Set the ring buffer values to zero.
void Reset() { buffer_.fill(0); }
// Replace the least recently pushed array in the buffer with `new_values`.
void Push(rtc::ArrayView<const T, S> new_values) {
std::memcpy(buffer_.data() + S * tail_, new_values.data(), S * sizeof(T));
tail_ += 1;
@ -43,13 +43,12 @@ class RingBuffer {
tail_ = 0;
}
// Return an array view onto the array with a given delay. A view on the last
// and least recently pushed array is returned when `delay` is 0 and N - 1
// respectively.
rtc::ArrayView<const T, S> GetArrayView(int delay) const {
RTC_DCHECK_LE(0, delay);
RTC_DCHECK_LT(delay, N);
int offset = tail_ - 1 - delay;
if (offset < 0)
offset += N;
return {buffer_.data() + S * offset, S};
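A usage sketch for the ring buffer above; the sizes are illustrative:

// Keeps the last 3 arrays of 6 cepstral coefficients each.
RingBuffer<float, /*S=*/6, /*N=*/3> ring;
std::array<float, 6> coeffs{};  // Filled by the feature extractor.
ring.Push(coeffs);
rtc::ArrayView<const float, 6> most_recent = ring.GetArrayView(/*delay=*/0);
rtc::ArrayView<const float, 6> oldest = ring.GetArrayView(/*delay=*/2);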

View File

@ -10,415 +10,81 @@
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
namespace webrtc {
namespace rnn_vad {
namespace {
using ::rnnoise::kInputLayerInputSize;
static_assert(kFeatureVectorSize == kInputLayerInputSize, "");
using ::rnnoise::kInputDenseBias;
using ::rnnoise::kInputDenseWeights;
using ::rnnoise::kInputLayerOutputSize;
static_assert(kInputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
using ::rnnoise::kHiddenGruBias;
using ::rnnoise::kHiddenGruRecurrentWeights;
using ::rnnoise::kHiddenGruWeights;
using ::rnnoise::kHiddenLayerOutputSize;
static_assert(kHiddenLayerOutputSize <= kGruLayerMaxUnits, "");
using ::rnnoise::kOutputDenseBias;
using ::rnnoise::kOutputDenseWeights;
using ::rnnoise::kOutputLayerOutputSize;
static_assert(kOutputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
} // namespace
RnnVad::RnnVad(const AvailableCpuFeatures& cpu_features)
: input_(kInputLayerInputSize,
kInputLayerOutputSize,
kInputDenseBias,
kInputDenseWeights,
ActivationFunction::kTansigApproximated,
cpu_features,
/*layer_name=*/"FC1"),
hidden_(kInputLayerOutputSize,
kHiddenLayerOutputSize,
kHiddenGruBias,
kHiddenGruWeights,
kHiddenGruRecurrentWeights,
cpu_features,
/*layer_name=*/"GRU1"),
output_(kHiddenLayerOutputSize,
kOutputLayerOutputSize,
kOutputDenseBias,
kOutputDenseWeights,
ActivationFunction::kSigmoidApproximated,
// The output layer is just 24x1. The unoptimized code is faster.
NoAvailableCpuFeatures(),
/*layer_name=*/"FC2") {
// Input-output chaining size checks.
  RTC_DCHECK_EQ(input_.size(), hidden_.input_size())
<< "The input and the hidden layers sizes do not match.";
  RTC_DCHECK_EQ(hidden_.size(), output_.input_size())
<< "The hidden and the output layers sizes do not match.";
}
RnnVad::~RnnVad() = default;
void RnnVad::Reset() {
  hidden_.Reset();
}
float RnnVad::ComputeVadProbability(
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
bool is_silence) {
if (is_silence) {
Reset();
return 0.f;
}
  input_.ComputeOutput(feature_vector);
  hidden_.ComputeOutput(input_);
  output_.ComputeOutput(hidden_);
  RTC_DCHECK_EQ(output_.size(), 1);
  return output_.data()[0];
}
} // namespace rnn_vad
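A usage sketch for the renamed class; the surrounding feature extraction is assumed, and `GetAvailableCpuFeatures()` is taken from agc2's cpu_features.h:

// One 10 ms step: feed a feature vector, read back the voice probability.
RnnVad vad(GetAvailableCpuFeatures());
std::array<float, kFeatureVectorSize> features{};  // From the extractor.
const float speech_probability =
    vad.ComputeVadProbability(features, /*is_silence=*/false);
RTC_DCHECK_GE(speech_probability, 0.f);
RTC_DCHECK_LE(speech_probability, 1.f);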

View File

@ -18,106 +18,33 @@
#include <vector>
#include "api/array_view.h"
#include "api/function_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/system/arch.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
namespace webrtc {
namespace rnn_vad {
// Recurrent network with hard-coded architecture and weights for voice activity
// detection.
class RnnVad {
public:
explicit RnnVad(const AvailableCpuFeatures& cpu_features);
RnnVad(const RnnVad&) = delete;
RnnVad& operator=(const RnnVad&) = delete;
~RnnVad();
void Reset();
// Observes `feature_vector` and `is_silence`, updates the RNN and returns the
// current voice probability.
float ComputeVadProbability(
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
bool is_silence);
private:
FullyConnectedLayer input_;
GatedRecurrentLayer hidden_;
FullyConnectedLayer output_;
};
} // namespace rnn_vad

View File

@ -0,0 +1,104 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
#include <algorithm>
#include <numeric>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
namespace webrtc {
namespace rnn_vad {
namespace {
std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
std::vector<float> scaled_params(params.size());
std::transform(params.begin(), params.end(), scaled_params.begin(),
[](int8_t x) -> float {
return ::rnnoise::kWeightsScale * static_cast<float>(x);
});
return scaled_params;
}
// TODO(bugs.chromium.org/10480): Hard-code optimized layout and remove this
// function to improve setup time.
// Casts and scales `weights` and re-arranges the layout.
std::vector<float> PreprocessWeights(rtc::ArrayView<const int8_t> weights,
int output_size) {
if (output_size == 1) {
return GetScaledParams(weights);
}
// Transpose, scale and cast.
const int input_size = rtc::CheckedDivExact(
rtc::dchecked_cast<int>(weights.size()), output_size);
std::vector<float> w(weights.size());
for (int o = 0; o < output_size; ++o) {
for (int i = 0; i < input_size; ++i) {
w[o * input_size + i] = rnnoise::kWeightsScale *
static_cast<float>(weights[i * output_size + o]);
}
}
return w;
}
rtc::FunctionView<float(float)> GetActivationFunction(
ActivationFunction activation_function) {
switch (activation_function) {
case ActivationFunction::kTansigApproximated:
return ::rnnoise::TansigApproximated;
case ActivationFunction::kSigmoidApproximated:
return ::rnnoise::SigmoidApproximated;
}
}
} // namespace
FullyConnectedLayer::FullyConnectedLayer(
const int input_size,
const int output_size,
const rtc::ArrayView<const int8_t> bias,
const rtc::ArrayView<const int8_t> weights,
ActivationFunction activation_function,
const AvailableCpuFeatures& cpu_features,
absl::string_view layer_name)
: input_size_(input_size),
output_size_(output_size),
bias_(GetScaledParams(bias)),
weights_(PreprocessWeights(weights, output_size)),
vector_math_(cpu_features),
activation_function_(GetActivationFunction(activation_function)) {
RTC_DCHECK_LE(output_size_, kFullyConnectedLayerMaxUnits)
<< "Insufficient FC layer over-allocation (" << layer_name << ").";
RTC_DCHECK_EQ(output_size_, bias_.size())
<< "Mismatching output size and bias terms array size (" << layer_name
<< ").";
RTC_DCHECK_EQ(input_size_ * output_size_, weights_.size())
<< "Mismatching input-output size and weight coefficients array size ("
<< layer_name << ").";
}
FullyConnectedLayer::~FullyConnectedLayer() = default;
void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
RTC_DCHECK_EQ(input.size(), input_size_);
rtc::ArrayView<const float> weights(weights_);
for (int o = 0; o < output_size_; ++o) {
output_[o] = activation_function_(
bias_[o] + vector_math_.DotProduct(
input, weights.subview(o * input_size_, input_size_)));
}
}
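In math terms, the loop above evaluates, for each output unit $o$ (with $\varphi$ the configured activation and $I$ = `input_size_`):

$$y_o = \varphi\Big(b_o + \sum_{i=0}^{I-1} W_{o,i}\,x_i\Big),$$

with the inner product delegated to `VectorMath::DotProduct()` so that the optimized SSE2/NEON paths are used when available.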
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,72 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
#include <array>
#include <vector>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "api/function_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
namespace webrtc {
namespace rnn_vad {
// Activation function for a neural network cell.
enum class ActivationFunction { kTansigApproximated, kSigmoidApproximated };
// Maximum number of units for an FC layer.
constexpr int kFullyConnectedLayerMaxUnits = 24;
// Fully-connected layer with a custom activation function which owns the output
// buffer.
class FullyConnectedLayer {
public:
// Ctor. `output_size` cannot be greater than `kFullyConnectedLayerMaxUnits`.
FullyConnectedLayer(int input_size,
int output_size,
rtc::ArrayView<const int8_t> bias,
rtc::ArrayView<const int8_t> weights,
ActivationFunction activation_function,
const AvailableCpuFeatures& cpu_features,
absl::string_view layer_name);
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
~FullyConnectedLayer();
// Returns the size of the input vector.
int input_size() const { return input_size_; }
// Returns the pointer to the first element of the output buffer.
const float* data() const { return output_.data(); }
// Returns the size of the output buffer.
int size() const { return output_size_; }
// Computes the fully-connected layer output.
void ComputeOutput(rtc::ArrayView<const float> input);
private:
const int input_size_;
const int output_size_;
const std::vector<float> bias_;
const std::vector<float> weights_;
const VectorMath vector_math_;
rtc::FunctionView<float(float)> activation_function_;
// Over-allocated array with size equal to `output_size_`.
std::array<float, kFullyConnectedLayerMaxUnits> output_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_

View File

@ -0,0 +1,198 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
namespace webrtc {
namespace rnn_vad {
namespace {
constexpr int kNumGruGates = 3; // Update, reset, output.
std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
int output_size) {
// Transpose, cast and scale.
// `n` is the size of the first dimension of the 3-dim tensor `tensor_src`.
const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
output_size * kNumGruGates);
const int stride_src = kNumGruGates * output_size;
const int stride_dst = n * output_size;
std::vector<float> tensor_dst(tensor_src.size());
for (int g = 0; g < kNumGruGates; ++g) {
for (int o = 0; o < output_size; ++o) {
for (int i = 0; i < n; ++i) {
tensor_dst[g * stride_dst + o * n + i] =
::rnnoise::kWeightsScale *
static_cast<float>(
tensor_src[i * stride_src + g * output_size + o]);
}
}
}
return tensor_dst;
}
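The index arithmetic above is a scale-and-transpose; in subscript form, with $G$ = `kNumGruGates`, $O$ = `output_size` and $n$ input rows:

$$\mathrm{dst}[g\,nO + o\,n + i] = s \cdot \mathrm{src}[i\,GO + g\,O + o],$$

i.e. the source interleaves the three gates within each input row, while the destination stores each gate as a contiguous, output-major block that `DotProduct()` can stream through sequentially.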
// Computes the output for the update or the reset gate.
// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
// - `g`: output gate vector
// - `W`: weights matrix
// - `i`: input vector
// - `R`: recurrent weights matrix
// - `s`: state gate vector
// - `b`: bias vector
void ComputeUpdateResetGate(int input_size,
int output_size,
const VectorMath& vector_math,
rtc::ArrayView<const float> input,
rtc::ArrayView<const float> state,
rtc::ArrayView<const float> bias,
rtc::ArrayView<const float> weights,
rtc::ArrayView<const float> recurrent_weights,
rtc::ArrayView<float> gate) {
RTC_DCHECK_EQ(input.size(), input_size);
RTC_DCHECK_EQ(state.size(), output_size);
RTC_DCHECK_EQ(bias.size(), output_size);
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated.
for (int o = 0; o < output_size; ++o) {
float x = bias[o];
x += vector_math.DotProduct(input,
weights.subview(o * input_size, input_size));
x += vector_math.DotProduct(
state, recurrent_weights.subview(o * output_size, output_size));
gate[o] = ::rnnoise::SigmoidApproximated(x);
}
}
// Computes the output for the state gate.
// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
// - `s'`: output state gate vector
// - `s`: previous state gate vector
// - `u`: update gate vector
// - `W`: weights matrix
// - `i`: input vector
// - `R`: recurrent weights matrix
// - `r`: reset gate vector
// - `b`: bias vector
// - `.*` element-wise product
void ComputeStateGate(int input_size,
int output_size,
const VectorMath& vector_math,
rtc::ArrayView<const float> input,
rtc::ArrayView<const float> update,
rtc::ArrayView<const float> reset,
rtc::ArrayView<const float> bias,
rtc::ArrayView<const float> weights,
rtc::ArrayView<const float> recurrent_weights,
rtc::ArrayView<float> state) {
RTC_DCHECK_EQ(input.size(), input_size);
RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated.
RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated.
RTC_DCHECK_EQ(bias.size(), output_size);
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
RTC_DCHECK_EQ(state.size(), output_size);
std::array<float, kGruLayerMaxUnits> reset_x_state;
for (int o = 0; o < output_size; ++o) {
reset_x_state[o] = state[o] * reset[o];
}
for (int o = 0; o < output_size; ++o) {
float x = bias[o];
x += vector_math.DotProduct(input,
weights.subview(o * input_size, input_size));
x += vector_math.DotProduct(
{reset_x_state.data(), static_cast<size_t>(output_size)},
recurrent_weights.subview(o * output_size, output_size));
state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
}
}
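Taken together, one `ComputeOutput()` step below evaluates, in the notation of the gate comments above ($\sigma$ = approximated sigmoid, $\odot$ = element-wise product):

$$u = \sigma(W_u^\top i + R_u^\top s + b_u), \qquad r = \sigma(W_r^\top i + R_r^\top s + b_r),$$

$$s' = u \odot s + (1 - u) \odot \mathrm{ReLU}\big(W_s^\top i + R_s^\top (s \odot r) + b_s\big).$$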
} // namespace
GatedRecurrentLayer::GatedRecurrentLayer(
const int input_size,
const int output_size,
const rtc::ArrayView<const int8_t> bias,
const rtc::ArrayView<const int8_t> weights,
const rtc::ArrayView<const int8_t> recurrent_weights,
const AvailableCpuFeatures& cpu_features,
absl::string_view layer_name)
: input_size_(input_size),
output_size_(output_size),
bias_(PreprocessGruTensor(bias, output_size)),
weights_(PreprocessGruTensor(weights, output_size)),
recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
vector_math_(cpu_features) {
RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
<< "Insufficient GRU layer over-allocation (" << layer_name << ").";
RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
<< "Mismatching output size and bias terms array size (" << layer_name
<< ").";
RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
<< "Mismatching input-output size and weight coefficients array size ("
<< layer_name << ").";
RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
recurrent_weights_.size())
<< "Mismatching input-output size and recurrent weight coefficients array"
" size ("
<< layer_name << ").";
Reset();
}
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
void GatedRecurrentLayer::Reset() {
state_.fill(0.f);
}
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
RTC_DCHECK_EQ(input.size(), input_size_);
// The tensors below are organized as a sequence of flattened tensors for the
// `update`, `reset` and `state` gates.
rtc::ArrayView<const float> bias(bias_);
rtc::ArrayView<const float> weights(weights_);
rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
// Strides to access to the flattened tensors for a specific gate.
const int stride_weights = input_size_ * output_size_;
const int stride_recurrent_weights = output_size_ * output_size_;
rtc::ArrayView<float> state(state_.data(), output_size_);
// Update gate.
std::array<float, kGruLayerMaxUnits> update;
ComputeUpdateResetGate(
input_size_, output_size_, vector_math_, input, state,
bias.subview(0, output_size_), weights.subview(0, stride_weights),
recurrent_weights.subview(0, stride_recurrent_weights), update);
// Reset gate.
std::array<float, kGruLayerMaxUnits> reset;
ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
bias.subview(output_size_, output_size_),
weights.subview(stride_weights, stride_weights),
recurrent_weights.subview(stride_recurrent_weights,
stride_recurrent_weights),
reset);
// State gate.
ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
reset, bias.subview(2 * output_size_, output_size_),
weights.subview(2 * stride_weights, stride_weights),
recurrent_weights.subview(2 * stride_recurrent_weights,
stride_recurrent_weights),
state);
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,70 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
#include <array>
#include <vector>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
namespace webrtc {
namespace rnn_vad {
// Maximum number of units for a GRU layer.
constexpr int kGruLayerMaxUnits = 24;
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
// activation functions for the update/reset and output gates respectively.
class GatedRecurrentLayer {
public:
// Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
GatedRecurrentLayer(int input_size,
int output_size,
rtc::ArrayView<const int8_t> bias,
rtc::ArrayView<const int8_t> weights,
rtc::ArrayView<const int8_t> recurrent_weights,
const AvailableCpuFeatures& cpu_features,
absl::string_view layer_name);
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
~GatedRecurrentLayer();
// Returns the size of the input vector.
int input_size() const { return input_size_; }
// Returns the pointer to the first element of the output buffer.
const float* data() const { return state_.data(); }
// Returns the size of the output buffer.
int size() const { return output_size_; }
// Resets the GRU state.
void Reset();
// Computes the recurrent layer output and updates the status.
void ComputeOutput(rtc::ArrayView<const float> input);
private:
const int input_size_;
const int output_size_;
const std::vector<float> bias_;
const std::vector<float> weights_;
const std::vector<float> recurrent_weights_;
const VectorMath vector_math_;
// Over-allocated array with size equal to `output_size_`.
std::array<float, kGruLayerMaxUnits> state_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_

View File

@ -29,7 +29,7 @@ namespace rnn_vad {
// values are written at the end of the buffer.
// The class also provides a view on the most recent M values, where 0 < M <= S
// and by default M = N.
template <typename T, int S, int N, int M = N>
class SequenceBuffer {
static_assert(N <= S,
"The new chunk size cannot be larger than the sequence buffer "
@ -45,8 +45,8 @@ class SequenceBuffer {
SequenceBuffer(const SequenceBuffer&) = delete;
SequenceBuffer& operator=(const SequenceBuffer&) = delete;
~SequenceBuffer() = default;
int size() const { return S; }
int chunks_size() const { return N; }
// Sets the sequence buffer values to zero.
void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); }
// Returns a view on the whole buffer.
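A usage sketch; `Push()` and the whole-buffer getter (named `GetBufferView()` here for illustration) are assumed from the class comments, and the sizes mirror the 24 kHz pitch buffer:

SequenceBuffer<float, /*S=*/864, /*N=*/480> sequence;
std::array<float, 480> chunk{};  // Most recent audio frame, filled elsewhere.
sequence.Push(chunk);  // Shifts out the oldest 480 samples, appends `chunk`.
rtc::ArrayView<const float, 864> view = sequence.GetBufferView();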

View File

@ -16,6 +16,7 @@
#include <numeric>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
@ -32,11 +33,11 @@ void UpdateCepstralDifferenceStats(
RTC_DCHECK(sym_matrix_buf);
// Compute the new cepstral distance stats.
std::array<float, kCepstralCoeffsHistorySize - 1> distances;
for (int i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
  const int delay = i + 1;
auto old_cepstral_coeffs = ring_buf.GetArrayView(delay);
distances[i] = 0.f;
for (int k = 0; k < kNumBands; ++k) {
const float c = new_cepstral_coeffs[k] - old_cepstral_coeffs[k];
distances[i] += c * c;
}
@ -48,9 +49,9 @@ void UpdateCepstralDifferenceStats(
// Computes the first half of the Vorbis window.
std::array<float, kFrameSize20ms24kHz / 2> ComputeScaledHalfVorbisWindow(
float scaling = 1.f) {
constexpr int kHalfSize = kFrameSize20ms24kHz / 2;
std::array<float, kHalfSize> half_window{};
for (int i = 0; i < kHalfSize; ++i) {
half_window[i] =
scaling *
std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfSize) *
@ -71,8 +72,8 @@ void ComputeWindowedForwardFft(
RTC_DCHECK_EQ(frame.size(), 2 * half_window.size());
// Apply windowing.
auto in = fft_input_buffer->GetView();
for (int i = 0, j = kFrameSize20ms24kHz - 1;
     rtc::SafeLt(i, half_window.size()); ++i, --j) {
in[i] = frame[i] * half_window[i];
in[j] = frame[j] * half_window[i];
}
@ -162,7 +163,7 @@ void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
RTC_DCHECK_EQ(average.size(), first_derivative.size());
RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size());
RTC_DCHECK_LE(average.size(), curr.size());
for (int i = 0; rtc::SafeLt(i, average.size()); ++i) {
// Average, kernel: [1, 1, 1].
average[i] = curr[i] + prev1[i] + prev2[i];
// First derivative, kernel: [1, 0, - 1].
@ -178,7 +179,7 @@ void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(),
bands_cross_corr_);
// Normalize.
for (int i = 0; rtc::SafeLt(i, bands_cross_corr_.size()); ++i) {
bands_cross_corr_[i] =
bands_cross_corr_[i] /
std::sqrt(0.001f + reference_frame_bands_energy_[i] *
@ -194,9 +195,9 @@ void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
float SpectralFeaturesExtractor::ComputeVariability() const {
// Compute cepstral variability score.
float variability = 0.f;
for (int delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
float min_dist = std::numeric_limits<float>::max();
for (int delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
if (delay1 == delay2) // The distance would be 0.
continue;
min_dist =

View File

@ -15,6 +15,7 @@
#include <cstddef>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
@ -22,7 +23,7 @@ namespace {
// Weights for each FFT coefficient for each Opus band (Nyquist frequency
// excluded). The size of each band is specified in
// `kOpusScaleNumBins24kHz20ms`.
constexpr std::array<float, kFrameSize20ms24kHz / 2> kOpusBandWeights24kHz20ms =
{{
0.f, 0.25f, 0.5f, 0.75f, // Band 0
@ -105,9 +106,9 @@ void SpectralCorrelator::ComputeCrossCorrelation(
RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
int k = 0;  // Next Fourier coefficient index.
cross_corr[0] = 0.f;
for (int i = 0; i < kOpusBands24kHz - 1; ++i) {
cross_corr[i + 1] = 0.f;
for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) { // Band size.
const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1];
@ -137,11 +138,11 @@ void ComputeSmoothedLogMagnitudeSpectrum(
return x;
};
// Smoothing over the bands for which the band energy is defined.
for (int i = 0; rtc::SafeLt(i, bands_energy.size()); ++i) {
log_bands_energy[i] = smooth(std::log10(kOneByHundred + bands_energy[i]));
}
// Smoothing over the remaining bands (zero energy).
for (int i = bands_energy.size(); i < kNumBands; ++i) {
log_bands_energy[i] = smooth(kLogOneByHundred);
}
}
@ -149,8 +150,8 @@ void ComputeSmoothedLogMagnitudeSpectrum(
std::array<float, kNumBands * kNumBands> ComputeDctTable() {
std::array<float, kNumBands * kNumBands> dct_table;
const double k = std::sqrt(0.5);
for (int i = 0; i < kNumBands; ++i) {
  for (int j = 0; j < kNumBands; ++j)
dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
dct_table[i * kNumBands] *= k;
}
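The table above is the standard DCT-II basis with the first column rescaled for orthonormality; in closed form, with $N$ = `kNumBands`:

$$T_{i,j} = \cos\!\left(\frac{(i + 0.5)\,j\,\pi}{N}\right), \qquad T_{i,0} \leftarrow \sqrt{1/2}\;T_{i,0}.$$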
@ -173,9 +174,9 @@ void ComputeDct(rtc::ArrayView<const float> in,
RTC_DCHECK_LE(in.size(), kNumBands);
RTC_DCHECK_LE(1, out.size());
RTC_DCHECK_LE(out.size(), in.size());
for (int i = 0; rtc::SafeLt(i, out.size()); ++i) {
out[i] = 0.f;
for (int j = 0; rtc::SafeLt(j, in.size()); ++j) {
out[i] += in[j] * dct_table[j * kNumBands + i];
}
// TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table.

View File

@ -25,7 +25,7 @@ namespace rnn_vad {
// At a sample rate of 24 kHz, the last 3 Opus bands are beyond the Nyquist
// frequency. However, band #19 gets the contributions from band #18 because
// of the symmetric triangular filter with peak response at 12 kHz.
constexpr int kOpusBands24kHz = 20;
static_assert(kOpusBands24kHz < kNumBands,
"The number of bands at 24 kHz must be less than those defined "
"in the Opus scale at 48 kHz.");
@ -50,8 +50,8 @@ class SpectralCorrelator {
~SpectralCorrelator();
// Computes the band-wise spectral auto-correlations.
// `x` must:
// - have size equal to `kFrameSize20ms24kHz`;
// - be encoded as vectors of interleaved real-complex FFT coefficients
// where x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
void ComputeAutoCorrelation(
@ -59,8 +59,8 @@ class SpectralCorrelator {
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const;
// Computes the band-wise spectral cross-correlations.
// `x` and `y` must:
// - have size equal to `kFrameSize20ms24kHz`;
// - be encoded as vectors of interleaved real-complex FFT coefficients where
// x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
void ComputeCrossCorrelation(
@ -82,12 +82,12 @@ void ComputeSmoothedLogMagnitudeSpectrum(
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
// spectral_features.cc. Creates a DCT table for arrays having size equal to
// `kNumBands`. Declared here for unit testing.
std::array<float, kNumBands * kNumBands> ComputeDctTable();
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
// spectral_features.cc. Computes DCT for `in` given a pre-computed DCT table.
// In-place computation is not allowed and `out` can be smaller than `in` in
// order to only compute the first DCT coefficients. Declared here for unit
// testing.
void ComputeDct(rtc::ArrayView<const float> in,

View File

@ -18,6 +18,7 @@
#include "api/array_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
@ -29,7 +30,7 @@ namespace rnn_vad {
// removed when one of the two corresponding items that have been compared is
// removed from the ring buffer. It is assumed that the comparison is symmetric
// and that comparing an item with itself is not needed.
template <typename T, int S>
class SymmetricMatrixBuffer {
static_assert(S > 2, "");
@ -45,9 +46,9 @@ class SymmetricMatrixBuffer {
buf_.fill(0);
}
// Pushes the results from the comparison between the most recent item and
// those that are still in the ring buffer. The first element in `values` must
// correspond to the comparison between the most recent item and the second
// most recent one in the ring buffer, whereas the last element in `values`
// must correspond to the comparison between the most recent item and the
// oldest one in the ring buffer.
void Push(rtc::ArrayView<T, S - 1> values) {
@ -55,19 +56,19 @@ class SymmetricMatrixBuffer {
// column left.
std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T));
// Copy new values in the last column in the right order.
for (int i = 0; rtc::SafeLt(i, values.size()); ++i) {
  const int index = (S - 1 - i) * (S - 1) - 1;
  RTC_DCHECK_GE(index, 0);
RTC_DCHECK_LT(index, buf_.size());
buf_[index] = values[i];
}
}
// Reads the value that corresponds to comparison of two items in the ring
// buffer having delay `delay1` and `delay2`. The two arguments must not be
// equal and both must be in {0, ..., S - 1}.
T GetValue(int delay1, int delay2) const {
  int row = S - 1 - delay1;
  int col = S - 1 - delay2;
RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed.";
if (row > col)
std::swap(row, col); // Swap to access the upper-right triangular part.

View File

@ -10,21 +10,61 @@
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "absl/strings/string_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/system/arch.h"
#include "system_wrappers/include/cpu_features_wrapper.h"
#include "rtc_base/numerics/safe_compare.h"
#include "test/gtest.h"
#include "test/testsupport/file_utils.h"
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
// File reader for binary files that contain a sequence of values with
// arithmetic type `T`. The values of type `T` that are read are cast to float.
template <typename T>
class FloatFileReader : public FileReader {
public:
static_assert(std::is_arithmetic<T>::value, "");
explicit FloatFileReader(absl::string_view filename)
: is_(std::string(filename), std::ios::binary | std::ios::ate),
size_(is_.tellg() / sizeof(T)) {
RTC_CHECK(is_);
SeekBeginning();
}
FloatFileReader(const FloatFileReader&) = delete;
FloatFileReader& operator=(const FloatFileReader&) = delete;
~FloatFileReader() = default;
int size() const override { return size_; }
bool ReadChunk(rtc::ArrayView<float> dst) override {
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
if (std::is_same<T, float>::value) {
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
} else {
buffer_.resize(dst.size());
is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read);
std::transform(buffer_.begin(), buffer_.end(), dst.begin(),
[](const T& v) -> float { return static_cast<float>(v); });
}
return is_.gcount() == bytes_to_read;
}
bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); }
void SeekForward(int hop) override { is_.seekg(hop * sizeof(T), is_.cur); }
void SeekBeginning() override { is_.seekg(0, is_.beg); }
private:
std::ifstream is_;
const int size_;
std::vector<T> buffer_;
};
} // namespace
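
A self-contained sketch of the read-and-widen path implemented by `FloatFileReader<int16_t>` above (hypothetical file name, plain iostream, no WebRTC dependencies):

#include <cstdint>
#include <fstream>
#include <vector>

int main() {
  // Write a few S16 samples, then read them back as floats, mirroring how
  // the PCM test data is fed to the RNN VAD tests.
  {
    std::ofstream os("samples.pcm", std::ios::binary);
    const int16_t samples[] = {0, 1000, -1000, 32767};
    os.write(reinterpret_cast<const char*>(samples), sizeof(samples));
  }
  std::ifstream is("samples.pcm", std::ios::binary | std::ios::ate);
  const int size = is.tellg() / sizeof(int16_t);  // Number of values.
  is.seekg(0, is.beg);
  std::vector<int16_t> buffer(size);
  is.read(reinterpret_cast<char*>(buffer.data()), size * sizeof(int16_t));
  std::vector<float> dst(buffer.begin(), buffer.end());  // Widening cast.
  return dst.size() == 4 ? 0 : 1;
}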
@ -33,7 +73,7 @@ using webrtc::test::ResourcePath;
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed) {
ASSERT_EQ(expected.size(), computed.size());
for (size_t i = 0; i < expected.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
SCOPED_TRACE(i);
EXPECT_FLOAT_EQ(expected[i], computed[i]);
}
@ -43,87 +83,61 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed,
float tolerance) {
ASSERT_EQ(expected.size(), computed.size());
for (size_t i = 0; i < expected.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
SCOPED_TRACE(i);
EXPECT_NEAR(expected[i], computed[i], tolerance);
}
}
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
CreatePcmSamplesReader(const size_t frame_length) {
auto ptr = std::make_unique<BinaryFileReader<int16_t, float>>(
test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"),
frame_length);
// The last incomplete frame is ignored.
return {std::move(ptr), ptr->data_length() / frame_length};
std::unique_ptr<FileReader> CreatePcmSamplesReader() {
return std::make_unique<FloatFileReader<int16_t>>(
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples",
"pcm"));
}
ReaderPairType CreatePitchBuffer24kHzReader() {
constexpr size_t cols = 864;
auto ptr = std::make_unique<BinaryFileReader<float>>(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)};
ChunksFileReader CreatePitchBuffer24kHzReader() {
auto reader = std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath(
"audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"));
const int num_chunks = rtc::CheckedDivExact(reader->size(), kBufSize24kHz);
return {/*chunk_size=*/kBufSize24kHz, num_chunks, std::move(reader)};
}
ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
constexpr size_t num_lp_residual_coeffs = 864;
auto ptr = std::make_unique<BinaryFileReader<float>>(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
num_lp_residual_coeffs);
return {std::move(ptr),
rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
ChunksFileReader CreateLpResidualAndPitchInfoReader() {
constexpr int kPitchInfoSize = 2; // Pitch period and strength.
constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize;
auto reader = std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath(
"audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"));
const int num_chunks = rtc::CheckedDivExact(reader->size(), kChunkSize);
return {kChunkSize, num_chunks, std::move(reader)};
}
ReaderPairType CreateVadProbsReader() {
auto ptr = std::make_unique<BinaryFileReader<float>>(
test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
return {std::move(ptr), ptr->data_length()};
std::unique_ptr<FileReader> CreateGruInputReader() {
return std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in",
"dat"));
}
std::unique_ptr<FileReader> CreateVadProbsReader() {
return std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
"dat"));
}
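
The chunked factories above rely on `rtc::CheckedDivExact`, so a truncated resource file fails loudly instead of silently dropping a partial chunk. A stand-in with the same contract (a sketch, not the rtc implementation):

#include <cassert>

// Returns num / den and asserts that the division leaves no remainder.
int checked_div_exact(int num, int den) {
  assert(num % den == 0 && "file size is not a whole number of chunks");
  return num / den;
}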
PitchTestData::PitchTestData() {
BinaryFileReader<float> test_data_reader(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
static_cast<size_t>(1396));
test_data_reader.ReadChunk(test_data_);
FloatFileReader<float> reader(
/*filename=*/ResourcePath(
"audio_processing/agc2/rnn_vad/pitch_search_int", "dat"));
reader.ReadChunk(pitch_buffer_24k_);
reader.ReadChunk(square_energies_24k_);
reader.ReadChunk(auto_correlation_12k_);
// Reverse the order of the squared energy values. This is required after
// WebRTC CL 191703, which switched to forward computation.
std::reverse(square_energies_24k_.begin(), square_energies_24k_.end());
}
PitchTestData::~PitchTestData() = default;
rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
const {
return {test_data_.data(), kBufSize24kHz};
}
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
PitchTestData::GetPitchBufSquareEnergiesView() const {
return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
}
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
kNumPitchBufAutoCorrCoeffs};
}
bool IsOptimizationAvailable(Optimization optimization) {
switch (optimization) {
case Optimization::kSse2:
#if defined(WEBRTC_ARCH_X86_FAMILY)
return GetCPUInfo(kSSE2) != 0;
#else
return false;
#endif
case Optimization::kNeon:
#if defined(WEBRTC_HAS_NEON)
return true;
#else
return false;
#endif
case Optimization::kNone:
return true;
}
}
} // namespace test
} // namespace rnn_vad
} // namespace webrtc

View File

@ -11,23 +11,19 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
#include <algorithm>
#include <array>
#include <fstream>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
namespace test {
constexpr float kFloatMin = std::numeric_limits<float>::min();
@ -42,98 +38,51 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed,
float tolerance);
// Reader for binary files consisting of an arbitrarily long sequence of
// elements having type T. It is possible to read and cast to another type D at
// once.
template <typename T, typename D = T>
class BinaryFileReader {
// File reader interface.
class FileReader {
public:
explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0)
: is_(file_path, std::ios::binary | std::ios::ate),
data_length_(is_.tellg() / sizeof(T)),
chunk_size_(chunk_size) {
RTC_CHECK(is_);
SeekBeginning();
buf_.resize(chunk_size_);
}
BinaryFileReader(const BinaryFileReader&) = delete;
BinaryFileReader& operator=(const BinaryFileReader&) = delete;
~BinaryFileReader() = default;
size_t data_length() const { return data_length_; }
bool ReadValue(D* dst) {
if (std::is_same<T, D>::value) {
is_.read(reinterpret_cast<char*>(dst), sizeof(T));
} else {
T v;
is_.read(reinterpret_cast<char*>(&v), sizeof(T));
*dst = static_cast<D>(v);
}
return is_.gcount() == sizeof(T);
}
// If |chunk_size| was specified in the ctor, it will check that the size of
// |dst| equals |chunk_size|.
bool ReadChunk(rtc::ArrayView<D> dst) {
RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size()));
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
if (std::is_same<T, D>::value) {
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
} else {
is_.read(reinterpret_cast<char*>(buf_.data()), bytes_to_read);
std::transform(buf_.begin(), buf_.end(), dst.begin(),
[](const T& v) -> D { return static_cast<D>(v); });
}
return is_.gcount() == bytes_to_read;
}
void SeekForward(size_t items) { is_.seekg(items * sizeof(T), is_.cur); }
void SeekBeginning() { is_.seekg(0, is_.beg); }
private:
std::ifstream is_;
const size_t data_length_;
const size_t chunk_size_;
std::vector<T> buf_;
virtual ~FileReader() = default;
// Number of values in the file.
virtual int size() const = 0;
// Reads `dst.size()` float values into `dst`, advances the internal file
// position according to the number of read bytes and returns true if the
// values are correctly read. If the number of remaining bytes in the file is
// not sufficient to read `dst.size()` float values, `dst` is partially
// modified and false is returned.
virtual bool ReadChunk(rtc::ArrayView<float> dst) = 0;
// Reads a single float value, advances the internal file position according
// to the number of read bytes and returns true if the value is correctly
// read. If the number of remaining bytes in the file is not sufficient to
// read one float, `dst` is not modified and false is returned.
virtual bool ReadValue(float& dst) = 0;
// Advances the internal file position by `hop` float values.
virtual void SeekForward(int hop) = 0;
// Resets the internal file position to BOF.
virtual void SeekBeginning() = 0;
};
// Writer for binary files.
template <typename T>
class BinaryFileWriter {
public:
explicit BinaryFileWriter(const std::string& file_path)
: os_(file_path, std::ios::binary) {}
BinaryFileWriter(const BinaryFileWriter&) = delete;
BinaryFileWriter& operator=(const BinaryFileWriter&) = delete;
~BinaryFileWriter() = default;
static_assert(std::is_arithmetic<T>::value, "");
void WriteChunk(rtc::ArrayView<const T> value) {
const std::streamsize bytes_to_write = value.size() * sizeof(T);
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
}
private:
std::ofstream os_;
// File reader for files that contain `num_chunks` chunks with size equal to
// `chunk_size`.
struct ChunksFileReader {
const int chunk_size;
const int num_chunks;
std::unique_ptr<FileReader> reader;
};
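
How a test might walk a `ChunksFileReader` end to end (a sketch; assumes one of the factories declared below):

ChunksFileReader chunks = CreatePitchBuffer24kHzReader();
std::vector<float> chunk(chunks.chunk_size);
for (int i = 0; i < chunks.num_chunks; ++i) {
  RTC_CHECK(chunks.reader->ReadChunk(chunk));
  // ... feed `chunk` to the unit under test ...
}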
// Factories for resource file readers.
// The functions below return a pair where the first item is a reader unique
// pointer and the second the number of chunks that can be read from the file.
// Creates a reader for the PCM samples that casts from S16 to float and reads
// chunks with length |frame_length|.
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
CreatePcmSamplesReader(const size_t frame_length);
// Creates a reader for the pitch buffer content at 24 kHz.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreatePitchBuffer24kHzReader();
// Creates a reader for the LP residual coefficients and the pitch period
// and gain values.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateLpResidualAndPitchPeriodGainReader();
// Creates a reader for the VAD probabilities.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateVadProbsReader();
// Creates a reader for the PCM S16 samples file.
std::unique_ptr<FileReader> CreatePcmSamplesReader();
constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
constexpr size_t kNumPitchBufSquareEnergies = 385;
constexpr size_t kPitchTestDataSize =
kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
// Creates a reader for the 24 kHz pitch buffer test data.
ChunksFileReader CreatePitchBuffer24kHzReader();
// Creates a reader for the LP residual and pitch information test data.
ChunksFileReader CreateLpResidualAndPitchInfoReader();
// Creates a reader for the sequence of GRU input vectors.
std::unique_ptr<FileReader> CreateGruInputReader();
// Creates a reader for the VAD probabilities test data.
std::unique_ptr<FileReader> CreateVadProbsReader();
// Class to retrieve the test pitch buffer content and the expected output for
// the analysis steps.
@ -141,20 +90,40 @@ class PitchTestData {
public:
PitchTestData();
~PitchTestData();
rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() const;
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
GetPitchBufSquareEnergiesView() const;
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
GetPitchBufAutoCorrCoeffsView() const;
rtc::ArrayView<const float, kBufSize24kHz> PitchBuffer24kHzView() const {
return pitch_buffer_24k_;
}
rtc::ArrayView<const float, kRefineNumLags24kHz> SquareEnergies24kHzView()
const {
return square_energies_24k_;
}
rtc::ArrayView<const float, kNumLags12kHz> AutoCorrelation12kHzView() const {
return auto_correlation_12k_;
}
private:
std::array<float, kPitchTestDataSize> test_data_;
std::array<float, kBufSize24kHz> pitch_buffer_24k_;
std::array<float, kRefineNumLags24kHz> square_energies_24k_;
std::array<float, kNumLags12kHz> auto_correlation_12k_;
};
// Returns true if the given optimization is available.
bool IsOptimizationAvailable(Optimization optimization);
// Writer for binary files.
class FileWriter {
public:
explicit FileWriter(absl::string_view file_path)
: os_(std::string(file_path), std::ios::binary) {}
FileWriter(const FileWriter&) = delete;
FileWriter& operator=(const FileWriter&) = delete;
~FileWriter() = default;
void WriteChunk(rtc::ArrayView<const float> value) {
const std::streamsize bytes_to_write = value.size() * sizeof(float);
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
}
private:
std::ofstream os_;
};
} // namespace test
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,114 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
// Defines WEBRTC_ARCH_X86_FAMILY, used below.
#include "rtc_base/system/arch.h"
#if defined(WEBRTC_HAS_NEON)
#include <arm_neon.h>
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
#include <numeric>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/system/arch.h"
namespace webrtc {
namespace rnn_vad {
// Provides optimizations for mathematical operations having vectors as
// operand(s).
class VectorMath {
public:
explicit VectorMath(AvailableCpuFeatures cpu_features)
: cpu_features_(cpu_features) {}
// Computes the dot product between two equally sized vectors.
float DotProduct(rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y) const {
RTC_DCHECK_EQ(x.size(), y.size());
#if defined(WEBRTC_ARCH_X86_FAMILY)
if (cpu_features_.avx2) {
return DotProductAvx2(x, y);
} else if (cpu_features_.sse2) {
__m128 accumulator = _mm_setzero_ps();
constexpr int kBlockSizeLog2 = 2;
constexpr int kBlockSize = 1 << kBlockSizeLog2;
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
<< kBlockSizeLog2;
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
RTC_DCHECK_LE(i + kBlockSize, x.size());
const __m128 x_i = _mm_loadu_ps(&x[i]);
const __m128 y_i = _mm_loadu_ps(&y[i]);
// Multiply-add.
const __m128 z_j = _mm_mul_ps(x_i, y_i);
accumulator = _mm_add_ps(accumulator, z_j);
}
// Reduce `accumulator` by addition.
__m128 high = _mm_movehl_ps(accumulator, accumulator);
accumulator = _mm_add_ps(accumulator, high);
high = _mm_shuffle_ps(accumulator, accumulator, 1);
accumulator = _mm_add_ps(accumulator, high);
float dot_product = _mm_cvtss_f32(accumulator);
// Add the result for the last block if incomplete.
for (int i = incomplete_block_index;
i < rtc::dchecked_cast<int>(x.size()); ++i) {
dot_product += x[i] * y[i];
}
return dot_product;
}
#elif defined(WEBRTC_HAS_NEON) && defined(WEBRTC_ARCH_ARM64)
if (cpu_features_.neon) {
float32x4_t accumulator = vdupq_n_f32(0.f);
constexpr int kBlockSizeLog2 = 2;
constexpr int kBlockSize = 1 << kBlockSizeLog2;
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
<< kBlockSizeLog2;
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
RTC_DCHECK_LE(i + kBlockSize, x.size());
const float32x4_t x_i = vld1q_f32(&x[i]);
const float32x4_t y_i = vld1q_f32(&y[i]);
accumulator = vfmaq_f32(accumulator, x_i, y_i);
}
// Reduce `accumulator` by addition.
const float32x2_t tmp =
vpadd_f32(vget_low_f32(accumulator), vget_high_f32(accumulator));
float dot_product = vget_lane_f32(vpadd_f32(tmp, vrev64_f32(tmp)), 0);
// Add the result for the last block if incomplete.
for (int i = incomplete_block_index;
i < rtc::dchecked_cast<int>(x.size()); ++i) {
dot_product += x[i] * y[i];
}
return dot_product;
}
#endif
return std::inner_product(x.begin(), x.end(), y.begin(), 0.f);
}
private:
float DotProductAvx2(rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y) const;
const AvailableCpuFeatures cpu_features_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
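
Both SIMD paths above use the same decomposition: process the largest multiple of 4 elements in width-4 blocks with a vector accumulator, reduce the accumulator horizontally, then finish the remaining tail in scalar code. A scalar rendering of that scheme (a sketch, not part of the file):

// Scalar equivalent of the blocked dot product used by the SSE2/NEON paths.
float BlockedDotProduct(const float* x, const float* y, int n) {
  constexpr int kBlockSizeLog2 = 2;
  constexpr int kBlockSize = 1 << kBlockSizeLog2;
  const int incomplete_block_index = (n >> kBlockSizeLog2) << kBlockSizeLog2;
  float acc[kBlockSize] = {};  // Stands in for the SIMD accumulator register.
  for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
    for (int j = 0; j < kBlockSize; ++j) {
      acc[j] += x[i + j] * y[i + j];  // Lane-wise multiply-add.
    }
  }
  float dot_product = acc[0] + acc[1] + acc[2] + acc[3];  // Horizontal add.
  for (int i = incomplete_block_index; i < n; ++i) {
    dot_product += x[i] * y[i];  // Tail of the last, incomplete block.
  }
  return dot_product;
}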

View File

@ -0,0 +1,54 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include <immintrin.h>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
namespace webrtc {
namespace rnn_vad {
float VectorMath::DotProductAvx2(rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y) const {
RTC_DCHECK(cpu_features_.avx2);
RTC_DCHECK_EQ(x.size(), y.size());
__m256 accumulator = _mm256_setzero_ps();
constexpr int kBlockSizeLog2 = 3;
constexpr int kBlockSize = 1 << kBlockSizeLog2;
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
<< kBlockSizeLog2;
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
RTC_DCHECK_LE(i + kBlockSize, x.size());
const __m256 x_i = _mm256_loadu_ps(&x[i]);
const __m256 y_i = _mm256_loadu_ps(&y[i]);
accumulator = _mm256_fmadd_ps(x_i, y_i, accumulator);
}
// Reduce `accumulator` by addition.
__m128 high = _mm256_extractf128_ps(accumulator, 1);
__m128 low = _mm256_extractf128_ps(accumulator, 0);
low = _mm_add_ps(high, low);
high = _mm_movehl_ps(high, low);
low = _mm_add_ps(high, low);
high = _mm_shuffle_ps(low, low, 1);
low = _mm_add_ss(high, low);
float dot_product = _mm_cvtss_f32(low);
// Add the result for the last block if incomplete.
for (int i = incomplete_block_index; i < rtc::dchecked_cast<int>(x.size());
++i) {
dot_product += x[i] * y[i];
}
return dot_product;
}
} // namespace rnn_vad
} // namespace webrtc
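
The reduction above halves the accumulator width at each step: 8 lanes to 4 (extractf128 + add), 4 to 2 (movehl + add), 2 to 1 (shuffle + add_ss). The same arithmetic written out in scalar form (a sketch):

// What the AVX2 horizontal reduction computes, lane by lane.
float ReduceAdd8(const float v[8]) {
  float t4[4];
  for (int i = 0; i < 4; ++i) t4[i] = v[i] + v[i + 4];    // extractf128 + add
  float t2[2];
  for (int i = 0; i < 2; ++i) t2[i] = t4[i] + t4[i + 2];  // movehl + add
  return t2[0] + t2[1];                                   // shuffle + add_ss
}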

View File

@ -10,84 +10,59 @@
#include "modules/audio_processing/agc2/saturation_protector.h"
#include <memory>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/saturation_protector_buffer.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
constexpr float kMinLevelDbfs = -90.f;
constexpr int kPeakEnveloperSuperFrameLengthMs = 400;
constexpr float kMinMarginDb = 12.0f;
constexpr float kMaxMarginDb = 25.0f;
constexpr float kAttack = 0.9988493699365052f;
constexpr float kDecay = 0.9997697679981565f;
// Min/max margins are based on speech crest-factor.
constexpr float kMinMarginDb = 12.f;
constexpr float kMaxMarginDb = 25.f;
using saturation_protector_impl::RingBuffer;
} // namespace
bool RingBuffer::operator==(const RingBuffer& b) const {
RTC_DCHECK_LE(size_, buffer_.size());
RTC_DCHECK_LE(b.size_, b.buffer_.size());
if (size_ != b.size_) {
return false;
// Saturation protector state. Defined outside of `SaturationProtectorImpl` to
// implement check-point and restore ops.
struct SaturationProtectorState {
bool operator==(const SaturationProtectorState& s) const {
return headroom_db == s.headroom_db &&
peak_delay_buffer == s.peak_delay_buffer &&
max_peaks_dbfs == s.max_peaks_dbfs &&
time_since_push_ms == s.time_since_push_ms;
}
for (int i = 0, i0 = FrontIndex(), i1 = b.FrontIndex(); i < size_;
++i, ++i0, ++i1) {
if (buffer_[i0 % buffer_.size()] != b.buffer_[i1 % b.buffer_.size()]) {
return false;
}
inline bool operator!=(const SaturationProtectorState& s) const {
return !(*this == s);
}
return true;
}
void RingBuffer::Reset() {
next_ = 0;
size_ = 0;
}
float headroom_db;
SaturationProtectorBuffer peak_delay_buffer;
float max_peaks_dbfs;
int time_since_push_ms; // Time since the last ring buffer push operation.
};
void RingBuffer::PushBack(float v) {
RTC_DCHECK_GE(next_, 0);
RTC_DCHECK_GE(size_, 0);
RTC_DCHECK_LT(next_, buffer_.size());
RTC_DCHECK_LE(size_, buffer_.size());
buffer_[next_++] = v;
if (rtc::SafeEq(next_, buffer_.size())) {
next_ = 0;
}
if (rtc::SafeLt(size_, buffer_.size())) {
size_++;
}
}
absl::optional<float> RingBuffer::Front() const {
if (size_ == 0) {
return absl::nullopt;
}
RTC_DCHECK_LT(FrontIndex(), buffer_.size());
return buffer_[FrontIndex()];
}
bool SaturationProtectorState::operator==(
const SaturationProtectorState& b) const {
return margin_db == b.margin_db && peak_delay_buffer == b.peak_delay_buffer &&
max_peaks_dbfs == b.max_peaks_dbfs &&
time_since_push_ms == b.time_since_push_ms;
}
void ResetSaturationProtectorState(float initial_margin_db,
// Resets the saturation protector state.
void ResetSaturationProtectorState(float initial_headroom_db,
SaturationProtectorState& state) {
state.margin_db = initial_margin_db;
state.headroom_db = initial_headroom_db;
state.peak_delay_buffer.Reset();
state.max_peaks_dbfs = kMinLevelDbfs;
state.time_since_push_ms = 0;
}
void UpdateSaturationProtectorState(float speech_peak_dbfs,
// Updates `state` by analyzing the estimated speech level `speech_level_dbfs`
// and the peak level `peak_dbfs` for an observed frame. `state` must not be
// modified without calling this function.
void UpdateSaturationProtectorState(float peak_dbfs,
float speech_level_dbfs,
SaturationProtectorState& state) {
// Get the max peak over `kPeakEnveloperSuperFrameLengthMs` ms.
state.max_peaks_dbfs = std::max(state.max_peaks_dbfs, speech_peak_dbfs);
state.max_peaks_dbfs = std::max(state.max_peaks_dbfs, peak_dbfs);
state.time_since_push_ms += kFrameDurationMs;
if (rtc::SafeGt(state.time_since_push_ms, kPeakEnveloperSuperFrameLengthMs)) {
// Push `max_peaks_dbfs` back into the ring buffer.
@ -97,25 +72,112 @@ void UpdateSaturationProtectorState(float speech_peak_dbfs,
state.time_since_push_ms = 0;
}
// Update margin by comparing the estimated speech level and the delayed max
// speech peak power.
// TODO(alessiob): Check with aleloi@ why we use a delay and how to tune it.
// Update the headroom by comparing the estimated speech level and the delayed
// max speech peak.
const float delayed_peak_dbfs =
state.peak_delay_buffer.Front().value_or(state.max_peaks_dbfs);
const float difference_db = delayed_peak_dbfs - speech_level_dbfs;
if (difference_db > state.margin_db) {
if (difference_db > state.headroom_db) {
// Attack.
state.margin_db =
state.margin_db * kSaturationProtectorAttackConstant +
difference_db * (1.f - kSaturationProtectorAttackConstant);
state.headroom_db =
state.headroom_db * kAttack + difference_db * (1.0f - kAttack);
} else {
// Decay.
state.margin_db = state.margin_db * kSaturationProtectorDecayConstant +
difference_db * (1.f - kSaturationProtectorDecayConstant);
state.headroom_db =
state.headroom_db * kDecay + difference_db * (1.0f - kDecay);
}
state.margin_db =
rtc::SafeClamp<float>(state.margin_db, kMinMarginDb, kMaxMarginDb);
state.headroom_db =
rtc::SafeClamp<float>(state.headroom_db, kMinMarginDb, kMaxMarginDb);
}
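
The update above is an asymmetric one-pole smoother: when the delayed peak sits further above the speech level than the current headroom, the faster attack constant pulls the headroom up quickly; otherwise the slower decay constant lets it drift down. The filter step in isolation (a sketch; constants copied from this hunk):

// One step of the asymmetric headroom smoother, before the clamping to
// [kMinMarginDb, kMaxMarginDb] shown above.
float SmoothHeadroomDb(float headroom_db, float difference_db) {
  const float kAttack = 0.9988493699365052f;
  const float kDecay = 0.9997697679981565f;
  const float c = (difference_db > headroom_db) ? kAttack : kDecay;
  return c * headroom_db + (1.0f - c) * difference_db;
}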
// Saturation protector which recommends a headroom based on the recent peaks.
class SaturationProtectorImpl : public SaturationProtector {
public:
explicit SaturationProtectorImpl(float initial_headroom_db,
int adjacent_speech_frames_threshold,
ApmDataDumper* apm_data_dumper)
: apm_data_dumper_(apm_data_dumper),
initial_headroom_db_(initial_headroom_db),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold) {
Reset();
}
SaturationProtectorImpl(const SaturationProtectorImpl&) = delete;
SaturationProtectorImpl& operator=(const SaturationProtectorImpl&) = delete;
~SaturationProtectorImpl() = default;
float HeadroomDb() override { return headroom_db_; }
void Analyze(float speech_probability,
float peak_dbfs,
float speech_level_dbfs) override {
if (speech_probability < kVadConfidenceThreshold) {
// Not a speech frame.
if (adjacent_speech_frames_threshold_ > 1) {
// When two or more adjacent speech frames are required in order to
// update the state, we need to decide whether to discard or confirm the
// updates based on the speech sequence length.
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// First non-speech frame after a long enough sequence of speech
// frames. Update the reliable state.
reliable_state_ = preliminary_state_;
} else if (num_adjacent_speech_frames_ > 0) {
// First non-speech frame after a too short sequence of speech frames.
// Reset to the last reliable state.
preliminary_state_ = reliable_state_;
}
}
num_adjacent_speech_frames_ = 0;
} else {
// Speech frame observed.
num_adjacent_speech_frames_++;
// Update preliminary level estimate.
UpdateSaturationProtectorState(peak_dbfs, speech_level_dbfs,
preliminary_state_);
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// `preliminary_state_` is now reliable. Update the headroom.
headroom_db_ = preliminary_state_.headroom_db;
}
}
DumpDebugData();
}
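
`Analyze` implements a confirm-or-discard state machine that reappears in `SpeechLevelEstimator` further down: preliminary updates only become reliable after `adjacent_speech_frames_threshold_` consecutive speech frames, and shorter speech runs are rolled back. The pattern in isolation (a sketch with a hypothetical `State` type):

// Confirm-or-discard update shared by the saturation protector and the
// speech level estimator. `preliminary` is updated elsewhere on each
// speech frame.
template <typename State>
void ConfirmOrDiscard(bool is_speech, int threshold, int& speech_run_length,
                      State& preliminary, State& reliable) {
  if (!is_speech) {
    if (speech_run_length >= threshold) {
      reliable = preliminary;  // Long enough run: confirm the updates.
    } else if (speech_run_length > 0) {
      preliminary = reliable;  // Too short: roll back to the reliable state.
    }
    speech_run_length = 0;
  } else {
    ++speech_run_length;
  }
}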
void Reset() override {
num_adjacent_speech_frames_ = 0;
headroom_db_ = initial_headroom_db_;
ResetSaturationProtectorState(initial_headroom_db_, preliminary_state_);
ResetSaturationProtectorState(initial_headroom_db_, reliable_state_);
}
private:
void DumpDebugData() {
apm_data_dumper_->DumpRaw(
"agc2_saturation_protector_preliminary_max_peak_dbfs",
preliminary_state_.max_peaks_dbfs);
apm_data_dumper_->DumpRaw(
"agc2_saturation_protector_reliable_max_peak_dbfs",
reliable_state_.max_peaks_dbfs);
}
ApmDataDumper* const apm_data_dumper_;
const float initial_headroom_db_;
const int adjacent_speech_frames_threshold_;
int num_adjacent_speech_frames_;
float headroom_db_;
SaturationProtectorState preliminary_state_;
SaturationProtectorState reliable_state_;
};
} // namespace
std::unique_ptr<SaturationProtector> CreateSaturationProtector(
float initial_headroom_db,
int adjacent_speech_frames_threshold,
ApmDataDumper* apm_data_dumper) {
return std::make_unique<SaturationProtectorImpl>(
initial_headroom_db, adjacent_speech_frames_threshold, apm_data_dumper);
}
} // namespace webrtc

View File

@ -11,71 +11,35 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
#include <array>
#include "absl/types/optional.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/numerics/safe_compare.h"
#include <memory>
namespace webrtc {
namespace saturation_protector_impl {
class ApmDataDumper;
// Ring buffer which only supports (i) push back and (ii) read oldest item.
class RingBuffer {
// Saturation protector. Analyzes peak levels and recommends a headroom to
// reduce the chances of clipping.
class SaturationProtector {
public:
bool operator==(const RingBuffer& b) const;
inline bool operator!=(const RingBuffer& b) const { return !(*this == b); }
virtual ~SaturationProtector() = default;
// Maximum number of values that the buffer can contain.
int Capacity() const { return buffer_.size(); }
// Number of values in the buffer.
int Size() const { return size_; }
// Returns the recommended headroom in dB.
virtual float HeadroomDb() = 0;
void Reset();
// Pushes back `v`. If the buffer is full, the oldest value is replaced.
void PushBack(float v);
// Returns the oldest item in the buffer. Returns an empty value if the
// buffer is empty.
absl::optional<float> Front() const;
// Analyzes the peak level of a 10 ms frame along with its speech probability
// and the current speech level estimate to update the recommended headroom.
virtual void Analyze(float speech_probability,
float peak_dbfs,
float speech_level_dbfs) = 0;
private:
inline int FrontIndex() const {
return rtc::SafeEq(size_, buffer_.size()) ? next_ : 0;
}
// `buffer_` has `size_` elements (up to the size of `buffer_`) and `next_` is
// the position where the next new value is written in `buffer_`.
std::array<float, kPeakEnveloperBufferSize> buffer_;
int next_ = 0;
int size_ = 0;
// Resets the internal state.
virtual void Reset() = 0;
};
} // namespace saturation_protector_impl
// Saturation protector state. Exposed publicly for check-pointing and restore
// ops.
struct SaturationProtectorState {
bool operator==(const SaturationProtectorState& s) const;
inline bool operator!=(const SaturationProtectorState& s) const {
return !(*this == s);
}
float margin_db; // Recommended margin.
saturation_protector_impl::RingBuffer peak_delay_buffer;
float max_peaks_dbfs;
int time_since_push_ms; // Time since the last ring buffer push operation.
};
// Resets the saturation protector state.
void ResetSaturationProtectorState(float initial_margin_db,
SaturationProtectorState& state);
// Updates `state` by analyzing the estimated speech level `speech_level_dbfs`
// and the peak power `speech_peak_dbfs` for an observed frame which is
// reliably classified as "speech". `state` must not be modified without calling
// this function.
void UpdateSaturationProtectorState(float speech_peak_dbfs,
float speech_level_dbfs,
SaturationProtectorState& state);
// Creates a saturation protector that starts at `initial_headroom_db`.
std::unique_ptr<SaturationProtector> CreateSaturationProtector(
float initial_headroom_db,
int adjacent_speech_frames_threshold,
ApmDataDumper* apm_data_dumper);
} // namespace webrtc

View File

@ -0,0 +1,77 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/saturation_protector_buffer.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
SaturationProtectorBuffer::SaturationProtectorBuffer() = default;
SaturationProtectorBuffer::~SaturationProtectorBuffer() = default;
bool SaturationProtectorBuffer::operator==(
const SaturationProtectorBuffer& b) const {
RTC_DCHECK_LE(size_, buffer_.size());
RTC_DCHECK_LE(b.size_, b.buffer_.size());
if (size_ != b.size_) {
return false;
}
for (int i = 0, i0 = FrontIndex(), i1 = b.FrontIndex(); i < size_;
++i, ++i0, ++i1) {
if (buffer_[i0 % buffer_.size()] != b.buffer_[i1 % b.buffer_.size()]) {
return false;
}
}
return true;
}
int SaturationProtectorBuffer::Capacity() const {
return buffer_.size();
}
int SaturationProtectorBuffer::Size() const {
return size_;
}
void SaturationProtectorBuffer::Reset() {
next_ = 0;
size_ = 0;
}
void SaturationProtectorBuffer::PushBack(float v) {
RTC_DCHECK_GE(next_, 0);
RTC_DCHECK_GE(size_, 0);
RTC_DCHECK_LT(next_, buffer_.size());
RTC_DCHECK_LE(size_, buffer_.size());
buffer_[next_++] = v;
if (rtc::SafeEq(next_, buffer_.size())) {
next_ = 0;
}
if (rtc::SafeLt(size_, buffer_.size())) {
size_++;
}
}
absl::optional<float> SaturationProtectorBuffer::Front() const {
if (size_ == 0) {
return absl::nullopt;
}
RTC_DCHECK_LT(FrontIndex(), buffer_.size());
return buffer_[FrontIndex()];
}
int SaturationProtectorBuffer::FrontIndex() const {
return rtc::SafeEq(size_, buffer_.size()) ? next_ : 0;
}
} // namespace webrtc
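
Usage sketch for the ring buffer semantics above (a capacity of 4 is assumed purely for illustration; the real capacity is kSaturationProtectorBufferSize):

SaturationProtectorBuffer buffer;
for (float v : {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}) {
  buffer.PushBack(v);
}
// With capacity 4, the first value was overwritten, so the oldest
// remaining value is 2.0f.
float front = buffer.Front().value_or(-1.0f);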

View File

@ -0,0 +1,59 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_
#include <array>
#include "absl/types/optional.h"
#include "modules/audio_processing/agc2/agc2_common.h"
namespace webrtc {
// Ring buffer for the saturation protector which only supports (i) push back
// and (ii) read oldest item.
class SaturationProtectorBuffer {
public:
SaturationProtectorBuffer();
~SaturationProtectorBuffer();
bool operator==(const SaturationProtectorBuffer& b) const;
inline bool operator!=(const SaturationProtectorBuffer& b) const {
return !(*this == b);
}
// Maximum number of values that the buffer can contain.
int Capacity() const;
// Number of values in the buffer.
int Size() const;
void Reset();
// Pushes back `v`. If the buffer is full, the oldest value is replaced.
void PushBack(float v);
// Returns the oldest item in the buffer. Returns an empty value if the
// buffer is empty.
absl::optional<float> Front() const;
private:
int FrontIndex() const;
// `buffer_` has `size_` elements (up to the size of `buffer_`) and `next_` is
// the position where the next new value is written in `buffer_`.
std::array<float, kSaturationProtectorBufferSize> buffer_;
int next_ = 0;
int size_ = 0;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_

View File

@ -1,177 +0,0 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/signal_classifier.h"
#include <algorithm>
#include <numeric>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/down_sampler.h"
#include "modules/audio_processing/agc2/noise_spectrum_estimator.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "system_wrappers/include/cpu_features_wrapper.h"
namespace webrtc {
namespace {
bool IsSse2Available() {
#if defined(WEBRTC_ARCH_X86_FAMILY)
return GetCPUInfo(kSSE2) != 0;
#else
return false;
#endif
}
void RemoveDcLevel(rtc::ArrayView<float> x) {
RTC_DCHECK_LT(0, x.size());
float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f);
mean /= x.size();
for (float& v : x) {
v -= mean;
}
}
void PowerSpectrum(const OouraFft* ooura_fft,
rtc::ArrayView<const float> x,
rtc::ArrayView<float> spectrum) {
RTC_DCHECK_EQ(65, spectrum.size());
RTC_DCHECK_EQ(128, x.size());
float X[128];
std::copy(x.data(), x.data() + x.size(), X);
ooura_fft->Fft(X);
float* X_p = X;
RTC_DCHECK_EQ(X_p, &X[0]);
spectrum[0] = (*X_p) * (*X_p);
++X_p;
RTC_DCHECK_EQ(X_p, &X[1]);
spectrum[64] = (*X_p) * (*X_p);
for (int k = 1; k < 64; ++k) {
++X_p;
RTC_DCHECK_EQ(X_p, &X[2 * k]);
spectrum[k] = (*X_p) * (*X_p);
++X_p;
RTC_DCHECK_EQ(X_p, &X[2 * k + 1]);
spectrum[k] += (*X_p) * (*X_p);
}
}
webrtc::SignalClassifier::SignalType ClassifySignal(
rtc::ArrayView<const float> signal_spectrum,
rtc::ArrayView<const float> noise_spectrum,
ApmDataDumper* data_dumper) {
int num_stationary_bands = 0;
int num_highly_nonstationary_bands = 0;
// Detect stationary and highly nonstationary bands.
for (size_t k = 1; k < 40; k++) {
if (signal_spectrum[k] < 3 * noise_spectrum[k] &&
signal_spectrum[k] * 3 > noise_spectrum[k]) {
++num_stationary_bands;
} else if (signal_spectrum[k] > 9 * noise_spectrum[k]) {
++num_highly_nonstationary_bands;
}
}
data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands);
data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1,
&num_highly_nonstationary_bands);
// Use the detected number of bands to classify the overall signal
// stationarity.
if (num_stationary_bands > 15) {
return SignalClassifier::SignalType::kStationary;
} else {
return SignalClassifier::SignalType::kNonStationary;
}
}
} // namespace
SignalClassifier::FrameExtender::FrameExtender(size_t frame_size,
size_t extended_frame_size)
: x_old_(extended_frame_size - frame_size, 0.f) {}
SignalClassifier::FrameExtender::~FrameExtender() = default;
void SignalClassifier::FrameExtender::ExtendFrame(
rtc::ArrayView<const float> x,
rtc::ArrayView<float> x_extended) {
RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size());
std::copy(x_old_.data(), x_old_.data() + x_old_.size(), x_extended.data());
std::copy(x.data(), x.data() + x.size(), x_extended.data() + x_old_.size());
std::copy(x_extended.data() + x_extended.size() - x_old_.size(),
x_extended.data() + x_extended.size(), x_old_.data());
}
SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper)
: data_dumper_(data_dumper),
down_sampler_(data_dumper_),
noise_spectrum_estimator_(data_dumper_),
ooura_fft_(IsSse2Available()) {
Initialize(48000);
}
SignalClassifier::~SignalClassifier() {}
void SignalClassifier::Initialize(int sample_rate_hz) {
down_sampler_.Initialize(sample_rate_hz);
noise_spectrum_estimator_.Initialize();
frame_extender_.reset(new FrameExtender(80, 128));
sample_rate_hz_ = sample_rate_hz;
initialization_frames_left_ = 2;
consistent_classification_counter_ = 3;
last_signal_type_ = SignalClassifier::SignalType::kNonStationary;
}
SignalClassifier::SignalType SignalClassifier::Analyze(
rtc::ArrayView<const float> signal) {
RTC_DCHECK_EQ(signal.size(), sample_rate_hz_ / 100);
// Compute the signal power spectrum.
float downsampled_frame[80];
down_sampler_.DownSample(signal, downsampled_frame);
float extended_frame[128];
frame_extender_->ExtendFrame(downsampled_frame, extended_frame);
RemoveDcLevel(extended_frame);
float signal_spectrum[65];
PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum);
// Classify the signal based on the estimate of the noise spectrum and the
// signal spectrum estimate.
const SignalType signal_type = ClassifySignal(
signal_spectrum, noise_spectrum_estimator_.GetNoiseSpectrum(),
data_dumper_);
// Update the noise spectrum based on the signal spectrum.
noise_spectrum_estimator_.Update(signal_spectrum,
initialization_frames_left_ > 0);
// Update the number of frames until a reliable signal spectrum is achieved.
initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1);
if (last_signal_type_ == signal_type) {
consistent_classification_counter_ =
std::max(0, consistent_classification_counter_ - 1);
} else {
last_signal_type_ = signal_type;
consistent_classification_counter_ = 3;
}
if (consistent_classification_counter_ > 0) {
return SignalClassifier::SignalType::kNonStationary;
}
return signal_type;
}
} // namespace webrtc

View File

@ -1,73 +0,0 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "common_audio/third_party/ooura/fft_size_128/ooura_fft.h"
#include "modules/audio_processing/agc2/down_sampler.h"
#include "modules/audio_processing/agc2/noise_spectrum_estimator.h"
namespace webrtc {
class ApmDataDumper;
class AudioBuffer;
class SignalClassifier {
public:
enum class SignalType { kNonStationary, kStationary };
explicit SignalClassifier(ApmDataDumper* data_dumper);
SignalClassifier() = delete;
SignalClassifier(const SignalClassifier&) = delete;
SignalClassifier& operator=(const SignalClassifier&) = delete;
~SignalClassifier();
void Initialize(int sample_rate_hz);
SignalType Analyze(rtc::ArrayView<const float> signal);
private:
class FrameExtender {
public:
FrameExtender(size_t frame_size, size_t extended_frame_size);
FrameExtender() = delete;
FrameExtender(const FrameExtender&) = delete;
FrameExtender& operator=(const FrameExtender&) = delete;
~FrameExtender();
void ExtendFrame(rtc::ArrayView<const float> x,
rtc::ArrayView<float> x_extended);
private:
std::vector<float> x_old_;
};
ApmDataDumper* const data_dumper_;
DownSampler down_sampler_;
std::unique_ptr<FrameExtender> frame_extender_;
NoiseSpectrumEstimator noise_spectrum_estimator_;
int sample_rate_hz_;
int initialization_frames_left_;
int consistent_classification_counter_;
SignalType last_signal_type_;
const OouraFft ooura_fft_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_

View File

@ -0,0 +1,174 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/speech_level_estimator.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
float ClampLevelEstimateDbfs(float level_estimate_dbfs) {
return rtc::SafeClamp<float>(level_estimate_dbfs, -90.0f, 30.0f);
}
// Returns the initial speech level estimate needed to apply the initial gain.
float GetInitialSpeechLevelEstimateDbfs(
const AudioProcessing::Config::GainController2::AdaptiveDigital& config) {
return ClampLevelEstimateDbfs(-kSaturationProtectorInitialHeadroomDb -
config.initial_gain_db - config.headroom_db);
}
} // namespace
bool SpeechLevelEstimator::LevelEstimatorState::operator==(
const SpeechLevelEstimator::LevelEstimatorState& b) const {
return time_to_confidence_ms == b.time_to_confidence_ms &&
level_dbfs.numerator == b.level_dbfs.numerator &&
level_dbfs.denominator == b.level_dbfs.denominator;
}
float SpeechLevelEstimator::LevelEstimatorState::Ratio::GetRatio() const {
RTC_DCHECK_NE(denominator, 0.f);
return numerator / denominator;
}
SpeechLevelEstimator::SpeechLevelEstimator(
ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
int adjacent_speech_frames_threshold)
: apm_data_dumper_(apm_data_dumper),
initial_speech_level_dbfs_(GetInitialSpeechLevelEstimateDbfs(config)),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
level_dbfs_(initial_speech_level_dbfs_),
// TODO(bugs.webrtc.org/7494): Remove init below when AGC2 input volume
// controller temporal dependency removed.
is_confident_(false) {
RTC_DCHECK(apm_data_dumper_);
RTC_DCHECK_GE(adjacent_speech_frames_threshold_, 1);
Reset();
}
void SpeechLevelEstimator::Update(float rms_dbfs,
float peak_dbfs,
float speech_probability) {
RTC_DCHECK_GT(rms_dbfs, -150.0f);
RTC_DCHECK_LT(rms_dbfs, 50.0f);
RTC_DCHECK_GT(peak_dbfs, -150.0f);
RTC_DCHECK_LT(peak_dbfs, 50.0f);
RTC_DCHECK_GE(speech_probability, 0.0f);
RTC_DCHECK_LE(speech_probability, 1.0f);
if (speech_probability < kVadConfidenceThreshold) {
// Not a speech frame.
if (adjacent_speech_frames_threshold_ > 1) {
// When two or more adjacent speech frames are required in order to update
// the state, we need to decide whether to discard or confirm the updates
// based on the speech sequence length.
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// First non-speech frame after a long enough sequence of speech frames.
// Update the reliable state.
reliable_state_ = preliminary_state_;
} else if (num_adjacent_speech_frames_ > 0) {
// First non-speech frame after a too short sequence of speech frames.
// Reset to the last reliable state.
preliminary_state_ = reliable_state_;
}
}
num_adjacent_speech_frames_ = 0;
} else {
// Speech frame observed.
num_adjacent_speech_frames_++;
// Update preliminary level estimate.
RTC_DCHECK_GE(preliminary_state_.time_to_confidence_ms, 0);
const bool buffer_is_full = preliminary_state_.time_to_confidence_ms == 0;
if (!buffer_is_full) {
preliminary_state_.time_to_confidence_ms -= kFrameDurationMs;
}
// Weighted average of levels with speech probability as weight.
RTC_DCHECK_GT(speech_probability, 0.0f);
const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.0f;
preliminary_state_.level_dbfs.numerator =
preliminary_state_.level_dbfs.numerator * leak_factor +
rms_dbfs * speech_probability;
preliminary_state_.level_dbfs.denominator =
preliminary_state_.level_dbfs.denominator * leak_factor +
speech_probability;
const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// `preliminary_state_` is now reliable. Update the last level estimation.
level_dbfs_ = ClampLevelEstimateDbfs(level_dbfs);
}
}
UpdateIsConfident();
DumpDebugData();
}
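
The speech-frame branch above keeps the level as a ratio of two leaky sums, i.e. a speech-probability-weighted average of RMS levels that, once the confidence window is full, forgets old frames through the leak factor. The step in isolation (a sketch; the leak value shown is an assumption for illustration, the real one is kLevelEstimatorLeakFactor):

struct Ratio {
  float numerator;
  float denominator;
};

// One update of the probability-weighted leaky average.
float UpdateWeightedAverage(Ratio& r, float rms_dbfs, float speech_probability,
                            bool buffer_is_full) {
  const float leak_factor = buffer_is_full ? 0.995f : 1.0f;  // Assumed value.
  r.numerator = r.numerator * leak_factor + rms_dbfs * speech_probability;
  r.denominator = r.denominator * leak_factor + speech_probability;
  return r.numerator / r.denominator;
}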
void SpeechLevelEstimator::UpdateIsConfident() {
if (adjacent_speech_frames_threshold_ == 1) {
// Ignore `reliable_state_` when a single frame is enough to update the
// level estimate (because it is not used).
is_confident_ = preliminary_state_.time_to_confidence_ms == 0;
return;
}
// Once confident, it remains confident.
RTC_DCHECK(reliable_state_.time_to_confidence_ms != 0 ||
preliminary_state_.time_to_confidence_ms == 0);
// During the first long enough speech sequence, `reliable_state_` must be
// ignored since `preliminary_state_` is used.
is_confident_ =
reliable_state_.time_to_confidence_ms == 0 ||
(num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_ &&
preliminary_state_.time_to_confidence_ms == 0);
}
void SpeechLevelEstimator::Reset() {
ResetLevelEstimatorState(preliminary_state_);
ResetLevelEstimatorState(reliable_state_);
level_dbfs_ = initial_speech_level_dbfs_;
num_adjacent_speech_frames_ = 0;
}
void SpeechLevelEstimator::ResetLevelEstimatorState(
LevelEstimatorState& state) const {
state.time_to_confidence_ms = kLevelEstimatorTimeToConfidenceMs;
state.level_dbfs.numerator = initial_speech_level_dbfs_;
state.level_dbfs.denominator = 1.0f;
}
void SpeechLevelEstimator::DumpDebugData() const {
if (!apm_data_dumper_)
return;
apm_data_dumper_->DumpRaw("agc2_speech_level_dbfs", level_dbfs_);
apm_data_dumper_->DumpRaw("agc2_speech_level_is_confident", is_confident_);
apm_data_dumper_->DumpRaw(
"agc2_adaptive_level_estimator_num_adjacent_speech_frames",
num_adjacent_speech_frames_);
apm_data_dumper_->DumpRaw(
"agc2_adaptive_level_estimator_preliminary_level_estimate_num",
preliminary_state_.level_dbfs.numerator);
apm_data_dumper_->DumpRaw(
"agc2_adaptive_level_estimator_preliminary_level_estimate_den",
preliminary_state_.level_dbfs.denominator);
apm_data_dumper_->DumpRaw(
"agc2_adaptive_level_estimator_preliminary_time_to_confidence_ms",
preliminary_state_.time_to_confidence_ms);
apm_data_dumper_->DumpRaw(
"agc2_adaptive_level_estimator_reliable_time_to_confidence_ms",
reliable_state_.time_to_confidence_ms);
}
} // namespace webrtc

View File

@ -8,40 +8,37 @@
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SPEECH_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SPEECH_LEVEL_ESTIMATOR_H_
#include <stddef.h>
#include <type_traits>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_processing.h"
namespace webrtc {
class ApmDataDumper;
// Level estimator for the digital adaptive gain controller.
class AdaptiveModeLevelEstimator {
// Active speech level estimator based on the analysis of the following
// framewise properties: RMS level (dBFS), peak level (dBFS), speech
// probability.
class SpeechLevelEstimator {
public:
explicit AdaptiveModeLevelEstimator(ApmDataDumper* apm_data_dumper);
AdaptiveModeLevelEstimator(const AdaptiveModeLevelEstimator&) = delete;
AdaptiveModeLevelEstimator& operator=(const AdaptiveModeLevelEstimator&) =
delete;
AdaptiveModeLevelEstimator(
SpeechLevelEstimator(
ApmDataDumper* apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator level_estimator,
int adjacent_speech_frames_threshold,
float initial_saturation_margin_db,
float extra_saturation_margin_db);
const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
int adjacent_speech_frames_threshold);
SpeechLevelEstimator(const SpeechLevelEstimator&) = delete;
SpeechLevelEstimator& operator=(const SpeechLevelEstimator&) = delete;
// Updates the level estimation.
void Update(const VadLevelAnalyzer::Result& vad_data);
void Update(float rms_dbfs, float peak_dbfs, float speech_probability);
// Returns the estimated speech plus noise level.
float level_dbfs() const { return level_dbfs_; }
// Returns true if the estimator is confident on its current estimate.
bool IsConfident() const;
bool is_confident() const { return is_confident_; }
void Reset();
@ -52,35 +49,33 @@ class AdaptiveModeLevelEstimator {
inline bool operator!=(const LevelEstimatorState& s) const {
return !(*this == s);
}
// TODO(bugs.webrtc.org/7494): Remove `time_to_confidence_ms` if redundant.
int time_to_confidence_ms;
struct Ratio {
float numerator;
float denominator;
float GetRatio() const;
};
// TODO(crbug.com/webrtc/7494): Remove time_to_full_buffer_ms if redundant.
int time_to_full_buffer_ms;
Ratio level_dbfs;
SaturationProtectorState saturation_protector;
} level_dbfs;
};
static_assert(std::is_trivially_copyable<LevelEstimatorState>::value, "");
void UpdateIsConfident();
void ResetLevelEstimatorState(LevelEstimatorState& state) const;
void DumpDebugData() const;
ApmDataDumper* const apm_data_dumper_;
const AudioProcessing::Config::GainController2::LevelEstimator
level_estimator_type_;
const float initial_speech_level_dbfs_;
const int adjacent_speech_frames_threshold_;
const float initial_saturation_margin_db_;
const float extra_saturation_margin_db_;
LevelEstimatorState preliminary_state_;
LevelEstimatorState reliable_state_;
float level_dbfs_;
bool is_confident_;
int num_adjacent_speech_frames_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
#endif // MODULES_AUDIO_PROCESSING_AGC2_SPEECH_LEVEL_ESTIMATOR_H_

View File

@ -0,0 +1,105 @@
/*
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/speech_probability_buffer.h"
#include <algorithm>
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr float kActivityThreshold = 0.9f;
constexpr int kNumAnalysisFrames = 100;
// We use 12 in AGC2 adaptive digital, but with slightly different logic.
constexpr int kTransientWidthThreshold = 7;
} // namespace
SpeechProbabilityBuffer::SpeechProbabilityBuffer(
float low_probability_threshold)
: low_probability_threshold_(low_probability_threshold),
probabilities_(kNumAnalysisFrames) {
RTC_DCHECK_GE(low_probability_threshold, 0.0f);
RTC_DCHECK_LE(low_probability_threshold, 1.0f);
RTC_DCHECK(!probabilities_.empty());
}
void SpeechProbabilityBuffer::Update(float probability) {
// Remove the oldest entry if the circular buffer is full.
if (buffer_is_full_) {
const float oldest_probability = probabilities_[buffer_index_];
sum_probabilities_ -= oldest_probability;
}
// Check for transients.
if (probability <= low_probability_threshold_) {
// Set a probability lower than the threshold to zero.
probability = 0.0f;
// Check if this has been a transient.
if (num_high_probability_observations_ <= kTransientWidthThreshold) {
RemoveTransient();
}
num_high_probability_observations_ = 0;
} else if (num_high_probability_observations_ <= kTransientWidthThreshold) {
++num_high_probability_observations_;
}
// Update the circular buffer and the current sum.
probabilities_[buffer_index_] = probability;
sum_probabilities_ += probability;
// Increment the buffer index and check for wrap-around.
if (++buffer_index_ >= kNumAnalysisFrames) {
buffer_index_ = 0;
buffer_is_full_ = true;
}
}
void SpeechProbabilityBuffer::RemoveTransient() {
// Don't expect to be here if the high-activity region is longer than
// `kTransientWidthThreshold` or there has not been any transient.
RTC_DCHECK_LE(num_high_probability_observations_, kTransientWidthThreshold);
// Replace previously added probabilities with zero.
int index =
(buffer_index_ > 0) ? (buffer_index_ - 1) : (kNumAnalysisFrames - 1);
while (num_high_probability_observations_-- > 0) {
sum_probabilities_ -= probabilities_[index];
probabilities_[index] = 0.0f;
// Update the circular buffer index.
index = (index > 0) ? (index - 1) : (kNumAnalysisFrames - 1);
}
}
bool SpeechProbabilityBuffer::IsActiveSegment() const {
if (!buffer_is_full_) {
return false;
}
if (sum_probabilities_ < kActivityThreshold * kNumAnalysisFrames) {
return false;
}
return true;
}
void SpeechProbabilityBuffer::Reset() {
sum_probabilities_ = 0.0f;
// Empty the circular buffer.
buffer_index_ = 0;
buffer_is_full_ = false;
num_high_probability_observations_ = 0;
}
} // namespace webrtc

View File

@ -0,0 +1,80 @@
/*
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_
#include <vector>
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
// This class implements a circular buffer that stores speech probabilities
// for a speech segment and estimates speech activity for that segment.
class SpeechProbabilityBuffer {
public:
// Ctor. The value of `low_probability_threshold` is required to be in the
// range [0.0f, 1.0f].
explicit SpeechProbabilityBuffer(float low_probability_threshold);
~SpeechProbabilityBuffer() {}
SpeechProbabilityBuffer(const SpeechProbabilityBuffer&) = delete;
SpeechProbabilityBuffer& operator=(const SpeechProbabilityBuffer&) = delete;
// Adds `probability` to the buffer and computes an updated sum of the buffer
// probabilities. The value of `probability` is required to be in the range
// [0.0f, 1.0f].
void Update(float probability);
// Resets the buffer, forgetting the past.
void Reset();
// Returns true if the segment is active (a long enough segment with an
// average speech probability above `low_probability_threshold`).
bool IsActiveSegment() const;
private:
void RemoveTransient();
// Use only for testing.
float GetSumProbabilities() const { return sum_probabilities_; }
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
CheckSumAfterInitialization);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, CheckSumAfterUpdate);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, CheckSumAfterReset);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
CheckSumAfterTransientNotRemoved);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
CheckSumAfterTransientRemoved);
const float low_probability_threshold_;
// Sum of probabilities stored in `probabilities_`. Must be updated if
// `probabilities_` is updated.
float sum_probabilities_ = 0.0f;
// Circular buffer for probabilities.
std::vector<float> probabilities_;
// Current index of the circular buffer, where the newest value is written;
// when the buffer is full, this is also the index of the oldest value.
int buffer_index_ = 0;
// Indicates if the buffer is full and adding a new value removes the oldest
// value.
bool buffer_is_full_ = false;
int num_high_probability_observations_ = 0;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_

View File

@ -1,114 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/vad_with_level.h"
#include <algorithm>
#include <array>
#include <cmath>
#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector;
// Default VAD that combines a resampler and the RNN VAD.
// Computes the speech probability on the first channel.
class Vad : public VoiceActivityDetector {
public:
Vad() = default;
Vad(const Vad&) = delete;
Vad& operator=(const Vad&) = delete;
~Vad() = default;
float ComputeProbability(AudioFrameView<const float> frame) override {
// The source number of channels is 1, because we always use the 1st
// channel.
resampler_.InitializeIfNeeded(
/*sample_rate_hz=*/static_cast<int>(frame.samples_per_channel() * 100),
rnn_vad::kSampleRate24kHz,
/*num_channels=*/1);
std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
// Feed the 1st channel to the resampler.
resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
work_frame.data(), rnn_vad::kFrameSize10ms24kHz);
std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
work_frame, feature_vector);
return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
}
private:
PushResampler<float> resampler_;
rnn_vad::FeaturesExtractor features_extractor_;
rnn_vad::RnnBasedVad rnn_vad_;
};
// Returns an updated version of `p_old` by using instant decay and the given
// `attack` on a new VAD probability value `p_new`.
float SmoothedVadProbability(float p_old, float p_new, float attack) {
RTC_DCHECK_GT(attack, 0.f);
RTC_DCHECK_LE(attack, 1.f);
if (p_new < p_old || attack == 1.f) {
// Instant decay (or no smoothing).
return p_new;
} else {
// Attack phase.
return attack * p_new + (1.f - attack) * p_old;
}
}
} // namespace
VadLevelAnalyzer::VadLevelAnalyzer()
: VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack,
std::make_unique<Vad>()) {}
VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack)
: VadLevelAnalyzer(vad_probability_attack, std::make_unique<Vad>()) {}
VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack,
std::unique_ptr<VoiceActivityDetector> vad)
: vad_(std::move(vad)), vad_probability_attack_(vad_probability_attack) {
RTC_DCHECK(vad_);
}
VadLevelAnalyzer::~VadLevelAnalyzer() = default;
VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
AudioFrameView<const float> frame) {
// Compute levels.
float peak = 0.f;
float rms = 0.f;
for (const auto& x : frame.channel(0)) {
peak = std::max(std::fabs(x), peak);
rms += x * x;
}
// Compute smoothed speech probability.
vad_probability_ = SmoothedVadProbability(
/*p_old=*/vad_probability_, /*p_new=*/vad_->ComputeProbability(frame),
vad_probability_attack_);
return {vad_probability_,
FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())),
FloatS16ToDbfs(peak)};
}
} // namespace webrtc
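The `SmoothedVadProbability()` helper above is a one-pole attack filter with instant decay; a standalone restatement with a worked value (the numbers are illustrative only):
// Rises with a one-pole attack, falls instantly. With attack = 0.3 and
// p_old = 0.2, a new value of 1.0 yields 0.3 * 1.0 + 0.7 * 0.2 = 0.44,
// while a new value of 0.1 is returned as-is (instant decay).
float SmoothedProbability(float p_old, float p_new, float attack) {
  if (p_new < p_old || attack == 1.0f) return p_new;
  return attack * p_new + (1.0f - attack) * p_old;
}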

View File

@ -1,58 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
#include <memory>
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
// Class to analyze voice activity and audio levels.
class VadLevelAnalyzer {
public:
struct Result {
float speech_probability; // Range: [0, 1].
float rms_dbfs; // Root mean square power (dBFS).
float peak_dbfs; // Peak power (dBFS).
};
// Voice Activity Detector (VAD) interface.
class VoiceActivityDetector {
public:
virtual ~VoiceActivityDetector() = default;
// Analyzes an audio frame and returns the speech probability.
virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
};
// Ctor. Uses the default VAD.
VadLevelAnalyzer();
explicit VadLevelAnalyzer(float vad_probability_attack);
// Ctor. Uses a custom `vad`.
VadLevelAnalyzer(float vad_probability_attack,
std::unique_ptr<VoiceActivityDetector> vad);
VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
~VadLevelAnalyzer();
// Computes the speech probability and the level for `frame`.
Result AnalyzeFrame(AudioFrameView<const float> frame);
private:
std::unique_ptr<VoiceActivityDetector> vad_;
const float vad_probability_attack_;
float vad_probability_ = 0.f;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_

View File

@ -0,0 +1,113 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/vad_wrapper.h"
#include <array>
#include <utility>
#include "api/array_view.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr int kNumFramesPerSecond = 100;
class MonoVadImpl : public VoiceActivityDetectorWrapper::MonoVad {
public:
explicit MonoVadImpl(const AvailableCpuFeatures& cpu_features)
: features_extractor_(cpu_features), rnn_vad_(cpu_features) {}
MonoVadImpl(const MonoVadImpl&) = delete;
MonoVadImpl& operator=(const MonoVadImpl&) = delete;
~MonoVadImpl() = default;
int SampleRateHz() const override { return rnn_vad::kSampleRate24kHz; }
void Reset() override { rnn_vad_.Reset(); }
float Analyze(rtc::ArrayView<const float> frame) override {
RTC_DCHECK_EQ(frame.size(), rnn_vad::kFrameSize10ms24kHz);
std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
/*samples=*/{frame.data(), rnn_vad::kFrameSize10ms24kHz},
feature_vector);
return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
}
private:
rnn_vad::FeaturesExtractor features_extractor_;
rnn_vad::RnnVad rnn_vad_;
};
} // namespace
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
const AvailableCpuFeatures& cpu_features,
int sample_rate_hz)
: VoiceActivityDetectorWrapper(kVadResetPeriodMs,
cpu_features,
sample_rate_hz) {}
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features,
int sample_rate_hz)
: VoiceActivityDetectorWrapper(vad_reset_period_ms,
std::make_unique<MonoVadImpl>(cpu_features),
sample_rate_hz) {}
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms,
std::unique_ptr<MonoVad> vad,
int sample_rate_hz)
: vad_reset_period_frames_(
rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
time_to_vad_reset_(vad_reset_period_frames_),
vad_(std::move(vad)) {
RTC_DCHECK(vad_);
RTC_DCHECK_GT(vad_reset_period_frames_, 1);
resampled_buffer_.resize(
rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond));
Initialize(sample_rate_hz);
}
VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
void VoiceActivityDetectorWrapper::Initialize(int sample_rate_hz) {
RTC_DCHECK_GT(sample_rate_hz, 0);
frame_size_ = rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond);
int status =
resampler_.InitializeIfNeeded(sample_rate_hz, vad_->SampleRateHz(),
/*num_channels=*/1);
constexpr int kStatusOk = 0;
RTC_DCHECK_EQ(status, kStatusOk);
vad_->Reset();
}
float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
// Periodically reset the VAD.
time_to_vad_reset_--;
if (time_to_vad_reset_ <= 0) {
vad_->Reset();
time_to_vad_reset_ = vad_reset_period_frames_;
}
// Resample the first channel of `frame`.
RTC_DCHECK_EQ(frame.samples_per_channel(), frame_size_);
resampler_.Resample(frame.channel(0).data(), frame_size_,
resampled_buffer_.data(), resampled_buffer_.size());
return vad_->Analyze(resampled_buffer_);
}
} // namespace webrtc
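A test-style sketch of driving the wrapper with a stub `MonoVad` — the rates, the reset period, and the three-argument `AudioFrameView` constructor used here are assumptions for illustration:
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/vad_wrapper.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
// Stub VAD that always reports speech; useful only to exercise the wrapper.
class AlwaysSpeechVad : public VoiceActivityDetectorWrapper::MonoVad {
 public:
  int SampleRateHz() const override { return 24000; }
  void Reset() override {}
  float Analyze(rtc::ArrayView<const float> frame) override { return 1.0f; }
};
void VadWrapperExample() {
  constexpr int kSampleRateHz = 48000;  // Assumed capture rate.
  VoiceActivityDetectorWrapper vad(/*vad_reset_period_ms=*/1000,
                                   std::make_unique<AlwaysSpeechVad>(),
                                   kSampleRateHz);
  // One 10 ms mono frame (480 samples at 48 kHz), all zeros.
  std::vector<float> samples(kSampleRateHz / 100, 0.0f);
  float* channels[] = {samples.data()};
  AudioFrameView<const float> frame(channels, /*num_channels=*/1,
                                    /*channel_size=*/480);
  const float p = vad.Analyze(frame);  // 1.0f from the stub.
  (void)p;
}
}  // namespace webrtc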

View File

@ -0,0 +1,82 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
// Wraps a single-channel Voice Activity Detector (VAD) which is used to analyze
// the first channel of the input audio frames. Takes care of resampling the
// input frames to match the sample rate of the wrapped VAD and periodically
// resets the VAD.
class VoiceActivityDetectorWrapper {
public:
// Single channel VAD interface.
class MonoVad {
public:
virtual ~MonoVad() = default;
// Returns the sample rate (Hz) required for the input frames analyzed by
// `Analyze`.
virtual int SampleRateHz() const = 0;
// Resets the internal state.
virtual void Reset() = 0;
// Analyzes an audio frame and returns the speech probability.
virtual float Analyze(rtc::ArrayView<const float> frame) = 0;
};
// Ctor. Uses `cpu_features` to instantiate the default VAD.
VoiceActivityDetectorWrapper(const AvailableCpuFeatures& cpu_features,
int sample_rate_hz);
// Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
// `MonoVad::Reset()`; it must be equal to or greater than the duration of two
// frames. Uses `cpu_features` to instantiate the default VAD.
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features,
int sample_rate_hz);
// Ctor. Uses a custom `vad`.
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
std::unique_ptr<MonoVad> vad,
int sample_rate_hz);
VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
delete;
~VoiceActivityDetectorWrapper();
// Initializes the VAD wrapper.
void Initialize(int sample_rate_hz);
// Analyzes the first channel of `frame` and returns the speech probability.
// `frame` must be a 10 ms frame with the sample rate specified in the last
// `Initialize()` call.
float Analyze(AudioFrameView<const float> frame);
private:
const int vad_reset_period_frames_;
int frame_size_;
int time_to_vad_reset_;
PushResampler<float> resampler_;
std::unique_ptr<MonoVad> vad_;
std::vector<float> resampled_buffer_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
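To make the sizing relationships in this header concrete, a small worked sketch; the input and VAD rates are illustrative assumptions (24 kHz matches the RNN VAD used by default):
// 10 ms framing implies 100 frames per second throughout.
constexpr int kNumFramesPerSecond = 100;
constexpr int kInputRateHz = 48000;  // Assumed capture rate.
constexpr int kVadRateHz = 24000;    // Rate of the wrapped VAD.
// Samples per frame handed to Analyze(), and after resampling for the VAD:
constexpr int kInputFrameSize = kInputRateHz / kNumFramesPerSecond;  // 480
constexpr int kVadFrameSize = kVadRateHz / kNumFramesPerSecond;      // 240
// A 1000 ms reset period amounts to one MonoVad::Reset() every 100 frames.
constexpr int kResetPeriodFrames = 1000 / 10;
static_assert(kInputFrameSize == 480, "");
static_assert(kVadFrameSize == 240, "");
static_assert(kResetPeriodFrames == 100, "");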