Update to current webrtc library
This is from the upstream library commit id 3326535126e435f1ba647885ce43a8f0f3d317eb, corresponding to Chromium 88.0.4290.1.
webrtc/modules/audio_processing/agc2/BUILD.gn (new file, 290 lines)
@@ -0,0 +1,290 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

import("../../../webrtc.gni")

group("agc2") {
  deps = [
    ":adaptive_digital",
    ":fixed_digital",
  ]
}

rtc_library("level_estimation_agc") {
  sources = [
    "adaptive_mode_level_estimator_agc.cc",
    "adaptive_mode_level_estimator_agc.h",
  ]
  configs += [ "..:apm_debug_dump" ]
  deps = [
    ":adaptive_digital",
    ":common",
    ":gain_applier",
    ":noise_level_estimator",
    ":rnn_vad_with_level",
    "..:api",
    "..:apm_logging",
    "..:audio_frame_view",
    "../../../api:array_view",
    "../../../common_audio",
    "../../../rtc_base:checks",
    "../../../rtc_base:rtc_base_approved",
    "../../../rtc_base:safe_minmax",
    "../agc:level_estimation",
    "../vad",
  ]
}

rtc_library("adaptive_digital") {
  sources = [
    "adaptive_agc.cc",
    "adaptive_agc.h",
    "adaptive_digital_gain_applier.cc",
    "adaptive_digital_gain_applier.h",
    "adaptive_mode_level_estimator.cc",
    "adaptive_mode_level_estimator.h",
    "saturation_protector.cc",
    "saturation_protector.h",
  ]

  configs += [ "..:apm_debug_dump" ]

  deps = [
    ":common",
    ":gain_applier",
    ":noise_level_estimator",
    ":rnn_vad_with_level",
    "..:api",
    "..:apm_logging",
    "..:audio_frame_view",
    "../../../api:array_view",
    "../../../common_audio",
    "../../../rtc_base:checks",
    "../../../rtc_base:logging",
    "../../../rtc_base:rtc_base_approved",
    "../../../rtc_base:safe_compare",
    "../../../rtc_base:safe_minmax",
    "../../../system_wrappers:metrics",
  ]

  absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
}

rtc_library("biquad_filter") {
  visibility = [ "./*" ]
  sources = [
    "biquad_filter.cc",
    "biquad_filter.h",
  ]
  deps = [
    "../../../api:array_view",
    "../../../rtc_base:rtc_base_approved",
  ]
}

rtc_source_set("common") {
  sources = [ "agc2_common.h" ]
}

rtc_library("fixed_digital") {
  sources = [
    "fixed_digital_level_estimator.cc",
    "fixed_digital_level_estimator.h",
    "interpolated_gain_curve.cc",
    "interpolated_gain_curve.h",
    "limiter.cc",
    "limiter.h",
  ]

  configs += [ "..:apm_debug_dump" ]

  deps = [
    ":common",
    "..:apm_logging",
    "..:audio_frame_view",
    "../../../api:array_view",
    "../../../common_audio",
    "../../../rtc_base:checks",
    "../../../rtc_base:gtest_prod",
    "../../../rtc_base:rtc_base_approved",
    "../../../rtc_base:safe_minmax",
    "../../../system_wrappers:metrics",
  ]
}

rtc_library("gain_applier") {
  sources = [
    "gain_applier.cc",
    "gain_applier.h",
  ]
  deps = [
    ":common",
    "..:audio_frame_view",
    "../../../api:array_view",
    "../../../rtc_base:safe_minmax",
  ]
}

rtc_library("noise_level_estimator") {
  sources = [
    "down_sampler.cc",
    "down_sampler.h",
    "noise_level_estimator.cc",
    "noise_level_estimator.h",
    "noise_spectrum_estimator.cc",
    "noise_spectrum_estimator.h",
    "signal_classifier.cc",
    "signal_classifier.h",
  ]
  deps = [
    ":biquad_filter",
    "..:apm_logging",
    "..:audio_frame_view",
    "../../../api:array_view",
    "../../../common_audio",
    "../../../common_audio/third_party/ooura:fft_size_128",
    "../../../rtc_base:checks",
    "../../../rtc_base:macromagic",
    "../../../system_wrappers",
  ]

  configs += [ "..:apm_debug_dump" ]
}

rtc_library("rnn_vad_with_level") {
  sources = [
    "vad_with_level.cc",
    "vad_with_level.h",
  ]
  deps = [
    ":common",
    "..:audio_frame_view",
    "../../../api:array_view",
    "../../../common_audio",
    "../../../rtc_base:checks",
    "rnn_vad",
    "rnn_vad:rnn_vad_common",
  ]
}

rtc_library("adaptive_digital_unittests") {
  testonly = true
  configs += [ "..:apm_debug_dump" ]

  sources = [
    "adaptive_digital_gain_applier_unittest.cc",
    "adaptive_mode_level_estimator_unittest.cc",
    "gain_applier_unittest.cc",
    "saturation_protector_unittest.cc",
  ]
  deps = [
    ":adaptive_digital",
    ":common",
    ":gain_applier",
    ":test_utils",
    "..:apm_logging",
    "..:audio_frame_view",
    "../../../api:array_view",
    "../../../common_audio",
    "../../../rtc_base:checks",
    "../../../rtc_base:gunit_helpers",
    "../../../rtc_base:rtc_base_approved",
    "../../../test:test_support",
  ]
}

rtc_library("biquad_filter_unittests") {
  testonly = true
  sources = [ "biquad_filter_unittest.cc" ]
  deps = [
    ":biquad_filter",
    "../../../rtc_base:gunit_helpers",
  ]
}

rtc_library("fixed_digital_unittests") {
  testonly = true
  configs += [ "..:apm_debug_dump" ]

  sources = [
    "agc2_testing_common_unittest.cc",
    "compute_interpolated_gain_curve.cc",
    "compute_interpolated_gain_curve.h",
    "fixed_digital_level_estimator_unittest.cc",
    "interpolated_gain_curve_unittest.cc",
    "limiter_db_gain_curve.cc",
    "limiter_db_gain_curve.h",
    "limiter_db_gain_curve_unittest.cc",
    "limiter_unittest.cc",
  ]
  deps = [
    ":common",
    ":fixed_digital",
    ":test_utils",
    "..:apm_logging",
    "..:audio_frame_view",
    "../../../api:array_view",
    "../../../common_audio",
    "../../../rtc_base:checks",
    "../../../rtc_base:gunit_helpers",
    "../../../rtc_base:rtc_base_approved",
    "../../../system_wrappers:metrics",
  ]
}

rtc_library("noise_estimator_unittests") {
  testonly = true
  configs += [ "..:apm_debug_dump" ]

  sources = [
    "noise_level_estimator_unittest.cc",
    "signal_classifier_unittest.cc",
  ]
  deps = [
    ":noise_level_estimator",
    ":test_utils",
    "..:apm_logging",
    "..:audio_frame_view",
    "../../../api:array_view",
    "../../../rtc_base:checks",
    "../../../rtc_base:gunit_helpers",
    "../../../rtc_base:rtc_base_approved",
  ]
}

rtc_library("rnn_vad_with_level_unittests") {
  testonly = true
  sources = [ "vad_with_level_unittest.cc" ]
  deps = [
    ":common",
    ":rnn_vad_with_level",
    "..:audio_frame_view",
    "../../../rtc_base:gunit_helpers",
    "../../../rtc_base:safe_compare",
    "../../../test:test_support",
  ]
}

rtc_library("test_utils") {
  testonly = true
  visibility = [
    ":*",
    "..:audio_processing_unittests",
  ]
  sources = [
    "agc2_testing_common.cc",
    "agc2_testing_common.h",
    "vector_float_frame.cc",
    "vector_float_frame.h",
  ]
  deps = [
    "..:audio_frame_view",
    "../../../rtc_base:checks",
    "../../../rtc_base:rtc_base_approved",
  ]
}
webrtc/modules/audio_processing/agc2/adaptive_agc.cc (new file, 90 lines)
@@ -0,0 +1,90 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/adaptive_agc.h"

#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"

namespace webrtc {
namespace {

void DumpDebugData(const AdaptiveDigitalGainApplier::FrameInfo& info,
                   ApmDataDumper& dumper) {
  dumper.DumpRaw("agc2_vad_probability", info.vad_result.speech_probability);
  dumper.DumpRaw("agc2_vad_rms_dbfs", info.vad_result.rms_dbfs);
  dumper.DumpRaw("agc2_vad_peak_dbfs", info.vad_result.peak_dbfs);
  dumper.DumpRaw("agc2_noise_estimate_dbfs", info.input_noise_level_dbfs);
  dumper.DumpRaw("agc2_last_limiter_audio_level", info.limiter_envelope_dbfs);
}

constexpr int kGainApplierAdjacentSpeechFramesThreshold = 1;
constexpr float kMaxGainChangePerSecondDb = 3.f;
constexpr float kMaxOutputNoiseLevelDbfs = -50.f;

} // namespace

AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper)
    : speech_level_estimator_(apm_data_dumper),
      gain_applier_(apm_data_dumper,
                    kGainApplierAdjacentSpeechFramesThreshold,
                    kMaxGainChangePerSecondDb,
                    kMaxOutputNoiseLevelDbfs),
      apm_data_dumper_(apm_data_dumper),
      noise_level_estimator_(apm_data_dumper) {
  RTC_DCHECK(apm_data_dumper);
}

AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper,
                         const AudioProcessing::Config::GainController2& config)
    : speech_level_estimator_(
          apm_data_dumper,
          config.adaptive_digital.level_estimator,
          config.adaptive_digital
              .level_estimator_adjacent_speech_frames_threshold,
          config.adaptive_digital.initial_saturation_margin_db,
          config.adaptive_digital.extra_saturation_margin_db),
      vad_(config.adaptive_digital.vad_probability_attack),
      gain_applier_(
          apm_data_dumper,
          config.adaptive_digital.gain_applier_adjacent_speech_frames_threshold,
          config.adaptive_digital.max_gain_change_db_per_second,
          config.adaptive_digital.max_output_noise_level_dbfs),
      apm_data_dumper_(apm_data_dumper),
      noise_level_estimator_(apm_data_dumper) {
  RTC_DCHECK(apm_data_dumper);
  if (!config.adaptive_digital.use_saturation_protector) {
    RTC_LOG(LS_WARNING) << "The saturation protector cannot be disabled.";
  }
}

AdaptiveAgc::~AdaptiveAgc() = default;

void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
  AdaptiveDigitalGainApplier::FrameInfo info;
  info.vad_result = vad_.AnalyzeFrame(frame);
  speech_level_estimator_.Update(info.vad_result);
  info.input_level_dbfs = speech_level_estimator_.level_dbfs();
  info.input_noise_level_dbfs = noise_level_estimator_.Analyze(frame);
  info.limiter_envelope_dbfs =
      limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.f;
  info.estimate_is_confident = speech_level_estimator_.IsConfident();
  DumpDebugData(info, *apm_data_dumper_);
  gain_applier_.Process(info, frame);
}

void AdaptiveAgc::Reset() {
  speech_level_estimator_.Reset();
}

} // namespace webrtc
webrtc/modules/audio_processing/agc2/adaptive_agc.h (new file, 50 lines)
@@ -0,0 +1,50 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_

#include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/noise_level_estimator.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/include/audio_processing.h"

namespace webrtc {
class ApmDataDumper;

// Adaptive digital gain controller.
// TODO(crbug.com/webrtc/7494): Unify with `AdaptiveDigitalGainApplier`.
class AdaptiveAgc {
 public:
  explicit AdaptiveAgc(ApmDataDumper* apm_data_dumper);
  // TODO(crbug.com/webrtc/7494): Remove ctor above.
  AdaptiveAgc(ApmDataDumper* apm_data_dumper,
              const AudioProcessing::Config::GainController2& config);
  ~AdaptiveAgc();

  // Analyzes `frame` and applies a digital adaptive gain to it. Takes into
  // account the envelope measured by the limiter.
  // TODO(crbug.com/webrtc/7494): Make the class depend on the limiter.
  void Process(AudioFrameView<float> frame, float limiter_envelope);
  void Reset();

 private:
  AdaptiveModeLevelEstimator speech_level_estimator_;
  VadLevelAnalyzer vad_;
  AdaptiveDigitalGainApplier gain_applier_;
  ApmDataDumper* const apm_data_dumper_;
  NoiseLevelEstimator noise_level_estimator_;
};

} // namespace webrtc

#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
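A minimal calling sketch (illustration only, not part of the upstream commit) showing how a 10 ms mono frame could be fed to AdaptiveAgc. The AudioFrameView constructor arguments, the 48 kHz frame size, and the limiter envelope value below are assumptions based on the headers in this commit, not confirmed API usage.

// Hypothetical caller sketch, assuming the interfaces declared above.
#include <vector>

#include "modules/audio_processing/agc2/adaptive_agc.h"
#include "modules/audio_processing/include/audio_frame_view.h"

void ProcessOneFrame(webrtc::AdaptiveAgc& agc, std::vector<float>& samples) {
  // One 10 ms mono frame, e.g. 480 samples at 48 kHz.
  float* channels[] = {samples.data()};
  webrtc::AudioFrameView<float> frame(channels, /*num_channels=*/1,
                                      /*samples_per_channel=*/samples.size());
  // The envelope would normally come from the fixed-digital limiter;
  // 0.f is only a placeholder here.
  agc.Process(frame, /*limiter_envelope=*/0.f);
}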
@@ -0,0 +1,179 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"

#include <algorithm>

#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
#include "system_wrappers/include/metrics.h"

namespace webrtc {
namespace {

// This function maps input level to desired applied gain. We want to
// boost the signal so that peaks are at -kHeadroomDbfs. We can't
// apply more than kMaxGainDb gain.
float ComputeGainDb(float input_level_dbfs) {
  // If the level is very low, boost it as much as we can.
  if (input_level_dbfs < -(kHeadroomDbfs + kMaxGainDb)) {
    return kMaxGainDb;
  }

  // We expect to end up here most of the time: the level is below
  // -headroom, but we can boost it to -headroom.
  if (input_level_dbfs < -kHeadroomDbfs) {
    return -kHeadroomDbfs - input_level_dbfs;
  }

  // Otherwise, the level is too high and we can't boost. The
  // LevelEstimator is responsible for not reporting bogus gain
  // values.
  RTC_DCHECK_LE(input_level_dbfs, 0.f);
  return 0.f;
}

// Returns `target_gain` if the output noise level is below
// `max_output_noise_level_dbfs`; otherwise returns a capped gain so that the
// output noise level equals `max_output_noise_level_dbfs`.
float LimitGainByNoise(float target_gain,
                       float input_noise_level_dbfs,
                       float max_output_noise_level_dbfs,
                       ApmDataDumper& apm_data_dumper) {
  const float noise_headroom_db =
      max_output_noise_level_dbfs - input_noise_level_dbfs;
  apm_data_dumper.DumpRaw("agc2_noise_headroom_db", noise_headroom_db);
  return std::min(target_gain, std::max(noise_headroom_db, 0.f));
}

float LimitGainByLowConfidence(float target_gain,
                               float last_gain,
                               float limiter_audio_level_dbfs,
                               bool estimate_is_confident) {
  if (estimate_is_confident ||
      limiter_audio_level_dbfs <= kLimiterThresholdForAgcGainDbfs) {
    return target_gain;
  }
  const float limiter_level_before_gain = limiter_audio_level_dbfs - last_gain;

  // Compute a new gain so that limiter_level_before_gain + new_gain <=
  // kLimiterThreshold.
  const float new_target_gain = std::max(
      kLimiterThresholdForAgcGainDbfs - limiter_level_before_gain, 0.f);
  return std::min(new_target_gain, target_gain);
}

// Computes how the gain should change during this frame.
// Return the gain difference in db to 'last_gain_db'.
float ComputeGainChangeThisFrameDb(float target_gain_db,
                                   float last_gain_db,
                                   bool gain_increase_allowed,
                                   float max_gain_change_db) {
  float target_gain_difference_db = target_gain_db - last_gain_db;
  if (!gain_increase_allowed) {
    target_gain_difference_db = std::min(target_gain_difference_db, 0.f);
  }
  return rtc::SafeClamp(target_gain_difference_db, -max_gain_change_db,
                        max_gain_change_db);
}

} // namespace

AdaptiveDigitalGainApplier::AdaptiveDigitalGainApplier(
    ApmDataDumper* apm_data_dumper,
    int adjacent_speech_frames_threshold,
    float max_gain_change_db_per_second,
    float max_output_noise_level_dbfs)
    : apm_data_dumper_(apm_data_dumper),
      gain_applier_(
          /*hard_clip_samples=*/false,
          /*initial_gain_factor=*/DbToRatio(kInitialAdaptiveDigitalGainDb)),
      adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
      max_gain_change_db_per_10ms_(max_gain_change_db_per_second *
                                   kFrameDurationMs / 1000.f),
      max_output_noise_level_dbfs_(max_output_noise_level_dbfs),
      calls_since_last_gain_log_(0),
      frames_to_gain_increase_allowed_(adjacent_speech_frames_threshold_),
      last_gain_db_(kInitialAdaptiveDigitalGainDb) {
  RTC_DCHECK_GT(max_gain_change_db_per_second, 0.f);
  RTC_DCHECK_GE(frames_to_gain_increase_allowed_, 1);
  RTC_DCHECK_GE(max_output_noise_level_dbfs_, -90.f);
  RTC_DCHECK_LE(max_output_noise_level_dbfs_, 0.f);
}

void AdaptiveDigitalGainApplier::Process(const FrameInfo& info,
                                         AudioFrameView<float> frame) {
  RTC_DCHECK_GE(info.input_level_dbfs, -150.f);
  RTC_DCHECK_GE(frame.num_channels(), 1);
  RTC_DCHECK(
      frame.samples_per_channel() == 80 || frame.samples_per_channel() == 160 ||
      frame.samples_per_channel() == 320 || frame.samples_per_channel() == 480)
      << "`frame` does not look like a 10 ms frame for an APM supported sample "
         "rate";

  const float target_gain_db = LimitGainByLowConfidence(
      LimitGainByNoise(ComputeGainDb(std::min(info.input_level_dbfs, 0.f)),
                       info.input_noise_level_dbfs,
                       max_output_noise_level_dbfs_, *apm_data_dumper_),
      last_gain_db_, info.limiter_envelope_dbfs, info.estimate_is_confident);

  // Forbid increasing the gain until enough adjacent speech frames are
  // observed.
  if (info.vad_result.speech_probability < kVadConfidenceThreshold) {
    frames_to_gain_increase_allowed_ = adjacent_speech_frames_threshold_;
  } else if (frames_to_gain_increase_allowed_ > 0) {
    frames_to_gain_increase_allowed_--;
  }

  const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb(
      target_gain_db, last_gain_db_,
      /*gain_increase_allowed=*/frames_to_gain_increase_allowed_ == 0,
      max_gain_change_db_per_10ms_);

  apm_data_dumper_->DumpRaw("agc2_want_to_change_by_db",
                            target_gain_db - last_gain_db_);
  apm_data_dumper_->DumpRaw("agc2_will_change_by_db",
                            gain_change_this_frame_db);

  // Optimization: avoid calling math functions if gain does not
  // change.
  if (gain_change_this_frame_db != 0.f) {
    gain_applier_.SetGainFactor(
        DbToRatio(last_gain_db_ + gain_change_this_frame_db));
  }
  gain_applier_.ApplyGain(frame);

  // Remember that the gain has changed for the next iteration.
  last_gain_db_ = last_gain_db_ + gain_change_this_frame_db;
  apm_data_dumper_->DumpRaw("agc2_applied_gain_db", last_gain_db_);

  // Log every 10 seconds.
  calls_since_last_gain_log_++;
  if (calls_since_last_gain_log_ == 1000) {
    calls_since_last_gain_log_ = 0;
    RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.DigitalGainApplied",
                                last_gain_db_, 0, kMaxGainDb, kMaxGainDb + 1);
    RTC_HISTOGRAM_COUNTS_LINEAR(
        "WebRTC.Audio.Agc2.EstimatedSpeechPlusNoiseLevel",
        -info.input_level_dbfs, 0, 100, 101);
    RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedNoiseLevel",
                                -info.input_noise_level_dbfs, 0, 100, 101);
    RTC_LOG(LS_INFO) << "AGC2 adaptive digital"
                     << " | speech_plus_noise_dbfs: " << info.input_level_dbfs
                     << " | noise_dbfs: " << info.input_noise_level_dbfs
                     << " | gain_db: " << last_gain_db_;
  }
}
} // namespace webrtc
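To make the gain mapping above concrete, here is a small standalone restatement of ComputeGainDb (illustration only, not code from the commit), using the constants restated from agc2_common.h: levels far below -(kHeadroomDbfs + kMaxGainDb) receive the full 30 dB boost, levels between that point and -kHeadroomDbfs are raised so the estimated level lands at -1 dBFS, and anything already above the headroom is left untouched.

// Standalone sketch of the ComputeGainDb mapping, for illustration only.
#include <cassert>

constexpr float kHeadroomDbfs = 1.f;  // Restated from agc2_common.h.
constexpr float kMaxGainDb = 30.f;    // Restated from agc2_common.h.

float ComputeGainDbSketch(float input_level_dbfs) {
  if (input_level_dbfs < -(kHeadroomDbfs + kMaxGainDb)) {
    return kMaxGainDb;  // Very low level: apply the maximum boost.
  }
  if (input_level_dbfs < -kHeadroomDbfs) {
    return -kHeadroomDbfs - input_level_dbfs;  // Raise the level to -1 dBFS.
  }
  return 0.f;  // Already at or above the headroom: no boost.
}

int main() {
  assert(ComputeGainDbSketch(-60.f) == 30.f);  // Capped at kMaxGainDb.
  assert(ComputeGainDbSketch(-20.f) == 19.f);  // -20 dBFS + 19 dB = -1 dBFS.
  assert(ComputeGainDbSketch(-0.5f) == 0.f);   // No boost above -1 dBFS.
  return 0;
}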
@@ -0,0 +1,69 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_

#include "modules/audio_processing/agc2/gain_applier.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h"

namespace webrtc {

class ApmDataDumper;

// Part of the adaptive digital controller that applies a digital adaptive gain.
// The gain is updated towards a target. The logic decides when gain updates are
// allowed, it controls the adaptation speed and caps the target based on the
// estimated noise level and the speech level estimate confidence.
class AdaptiveDigitalGainApplier {
 public:
  // Information about a frame to process.
  struct FrameInfo {
    float input_level_dbfs;        // Estimated speech plus noise level.
    float input_noise_level_dbfs;  // Estimated noise level.
    VadLevelAnalyzer::Result vad_result;
    float limiter_envelope_dbfs;  // Envelope level from the limiter.
    bool estimate_is_confident;
  };

  // Ctor.
  // `adjacent_speech_frames_threshold` indicates how many speech frames are
  // required before a gain increase is allowed. `max_gain_change_db_per_second`
  // limits the adaptation speed (uniformly operated across frames).
  // `max_output_noise_level_dbfs` limits the output noise level.
  AdaptiveDigitalGainApplier(ApmDataDumper* apm_data_dumper,
                             int adjacent_speech_frames_threshold,
                             float max_gain_change_db_per_second,
                             float max_output_noise_level_dbfs);
  AdaptiveDigitalGainApplier(const AdaptiveDigitalGainApplier&) = delete;
  AdaptiveDigitalGainApplier& operator=(const AdaptiveDigitalGainApplier&) =
      delete;

  // Analyzes `info`, updates the digital gain and applies it to a 10 ms
  // `frame`. Supports any sample rate supported by APM.
  void Process(const FrameInfo& info, AudioFrameView<float> frame);

 private:
  ApmDataDumper* const apm_data_dumper_;
  GainApplier gain_applier_;

  const int adjacent_speech_frames_threshold_;
  const float max_gain_change_db_per_10ms_;
  const float max_output_noise_level_dbfs_;

  int calls_since_last_gain_log_;
  int frames_to_gain_increase_allowed_;
  float last_gain_db_;
};

} // namespace webrtc

#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_
@@ -0,0 +1,198 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"

#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"

namespace webrtc {
namespace {

using LevelEstimatorType =
    AudioProcessing::Config::GainController2::LevelEstimator;

// Combines a level estimation with the saturation protector margins.
float ComputeLevelEstimateDbfs(float level_estimate_dbfs,
                               float saturation_margin_db,
                               float extra_saturation_margin_db) {
  return rtc::SafeClamp<float>(
      level_estimate_dbfs + saturation_margin_db + extra_saturation_margin_db,
      -90.f, 30.f);
}

// Returns the level of given type from `vad_level`.
float GetLevel(const VadLevelAnalyzer::Result& vad_level,
               LevelEstimatorType type) {
  switch (type) {
    case LevelEstimatorType::kRms:
      return vad_level.rms_dbfs;
      break;
    case LevelEstimatorType::kPeak:
      return vad_level.peak_dbfs;
      break;
  }
}

} // namespace

bool AdaptiveModeLevelEstimator::LevelEstimatorState::operator==(
    const AdaptiveModeLevelEstimator::LevelEstimatorState& b) const {
  return time_to_full_buffer_ms == b.time_to_full_buffer_ms &&
         level_dbfs.numerator == b.level_dbfs.numerator &&
         level_dbfs.denominator == b.level_dbfs.denominator &&
         saturation_protector == b.saturation_protector;
}

float AdaptiveModeLevelEstimator::LevelEstimatorState::Ratio::GetRatio() const {
  RTC_DCHECK_NE(denominator, 0.f);
  return numerator / denominator;
}

AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
    ApmDataDumper* apm_data_dumper)
    : AdaptiveModeLevelEstimator(
          apm_data_dumper,
          AudioProcessing::Config::GainController2::LevelEstimator::kRms,
          kDefaultLevelEstimatorAdjacentSpeechFramesThreshold,
          kDefaultInitialSaturationMarginDb,
          kDefaultExtraSaturationMarginDb) {}

AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
    ApmDataDumper* apm_data_dumper,
    AudioProcessing::Config::GainController2::LevelEstimator level_estimator,
    int adjacent_speech_frames_threshold,
    float initial_saturation_margin_db,
    float extra_saturation_margin_db)
    : apm_data_dumper_(apm_data_dumper),
      level_estimator_type_(level_estimator),
      adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
      initial_saturation_margin_db_(initial_saturation_margin_db),
      extra_saturation_margin_db_(extra_saturation_margin_db),
      level_dbfs_(ComputeLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs,
                                           initial_saturation_margin_db_,
                                           extra_saturation_margin_db_)) {
  RTC_DCHECK(apm_data_dumper_);
  RTC_DCHECK_GE(adjacent_speech_frames_threshold_, 1);
  Reset();
}

void AdaptiveModeLevelEstimator::Update(
    const VadLevelAnalyzer::Result& vad_level) {
  RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
  RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
  RTC_DCHECK_GT(vad_level.peak_dbfs, -150.f);
  RTC_DCHECK_LT(vad_level.peak_dbfs, 50.f);
  RTC_DCHECK_GE(vad_level.speech_probability, 0.f);
  RTC_DCHECK_LE(vad_level.speech_probability, 1.f);
  DumpDebugData();

  if (vad_level.speech_probability < kVadConfidenceThreshold) {
    // Not a speech frame.
    if (adjacent_speech_frames_threshold_ > 1) {
      // When two or more adjacent speech frames are required in order to update
      // the state, we need to decide whether to discard or confirm the updates
      // based on the speech sequence length.
      if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
        // First non-speech frame after a long enough sequence of speech frames.
        // Update the reliable state.
        reliable_state_ = preliminary_state_;
      } else if (num_adjacent_speech_frames_ > 0) {
        // First non-speech frame after a too short sequence of speech frames.
        // Reset to the last reliable state.
        preliminary_state_ = reliable_state_;
      }
    }
    num_adjacent_speech_frames_ = 0;
    return;
  }

  // Speech frame observed.
  num_adjacent_speech_frames_++;

  // Update preliminary level estimate.
  RTC_DCHECK_GE(preliminary_state_.time_to_full_buffer_ms, 0);
  const bool buffer_is_full = preliminary_state_.time_to_full_buffer_ms == 0;
  if (!buffer_is_full) {
    preliminary_state_.time_to_full_buffer_ms -= kFrameDurationMs;
  }
  // Weighted average of levels with speech probability as weight.
  RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
  const float leak_factor = buffer_is_full ? kFullBufferLeakFactor : 1.f;
  preliminary_state_.level_dbfs.numerator =
      preliminary_state_.level_dbfs.numerator * leak_factor +
      GetLevel(vad_level, level_estimator_type_) * vad_level.speech_probability;
  preliminary_state_.level_dbfs.denominator =
      preliminary_state_.level_dbfs.denominator * leak_factor +
      vad_level.speech_probability;

  const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();

  UpdateSaturationProtectorState(vad_level.peak_dbfs, level_dbfs,
                                 preliminary_state_.saturation_protector);

  if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
    // `preliminary_state_` is now reliable. Update the last level estimation.
    level_dbfs_ = ComputeLevelEstimateDbfs(
        level_dbfs, preliminary_state_.saturation_protector.margin_db,
        extra_saturation_margin_db_);
  }
}

bool AdaptiveModeLevelEstimator::IsConfident() const {
  if (adjacent_speech_frames_threshold_ == 1) {
    // Ignore `reliable_state_` when a single frame is enough to update the
    // level estimate (because it is not used).
    return preliminary_state_.time_to_full_buffer_ms == 0;
  }
  // Once confident, it remains confident.
  RTC_DCHECK(reliable_state_.time_to_full_buffer_ms != 0 ||
             preliminary_state_.time_to_full_buffer_ms == 0);
  // During the first long enough speech sequence, `reliable_state_` must be
  // ignored since `preliminary_state_` is used.
  return reliable_state_.time_to_full_buffer_ms == 0 ||
         (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_ &&
          preliminary_state_.time_to_full_buffer_ms == 0);
}

void AdaptiveModeLevelEstimator::Reset() {
  ResetLevelEstimatorState(preliminary_state_);
  ResetLevelEstimatorState(reliable_state_);
  level_dbfs_ = ComputeLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs,
                                         initial_saturation_margin_db_,
                                         extra_saturation_margin_db_);
  num_adjacent_speech_frames_ = 0;
}

void AdaptiveModeLevelEstimator::ResetLevelEstimatorState(
    LevelEstimatorState& state) const {
  state.time_to_full_buffer_ms = kFullBufferSizeMs;
  state.level_dbfs.numerator = 0.f;
  state.level_dbfs.denominator = 0.f;
  ResetSaturationProtectorState(initial_saturation_margin_db_,
                                state.saturation_protector);
}

void AdaptiveModeLevelEstimator::DumpDebugData() const {
  apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs", level_dbfs_);
  apm_data_dumper_->DumpRaw("agc2_adaptive_num_adjacent_speech_frames_",
                            num_adjacent_speech_frames_);
  apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_num",
                            preliminary_state_.level_dbfs.numerator);
  apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_den",
                            preliminary_state_.level_dbfs.denominator);
  apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_saturation_margin_db",
                            preliminary_state_.saturation_protector.margin_db);
}

} // namespace webrtc
@@ -0,0 +1,86 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_

#include <stddef.h>
#include <type_traits>

#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_processing.h"

namespace webrtc {
class ApmDataDumper;

// Level estimator for the digital adaptive gain controller.
class AdaptiveModeLevelEstimator {
 public:
  explicit AdaptiveModeLevelEstimator(ApmDataDumper* apm_data_dumper);
  AdaptiveModeLevelEstimator(const AdaptiveModeLevelEstimator&) = delete;
  AdaptiveModeLevelEstimator& operator=(const AdaptiveModeLevelEstimator&) =
      delete;
  AdaptiveModeLevelEstimator(
      ApmDataDumper* apm_data_dumper,
      AudioProcessing::Config::GainController2::LevelEstimator level_estimator,
      int adjacent_speech_frames_threshold,
      float initial_saturation_margin_db,
      float extra_saturation_margin_db);

  // Updates the level estimation.
  void Update(const VadLevelAnalyzer::Result& vad_data);
  // Returns the estimated speech plus noise level.
  float level_dbfs() const { return level_dbfs_; }
  // Returns true if the estimator is confident on its current estimate.
  bool IsConfident() const;

  void Reset();

 private:
  // Part of the level estimator state used for check-pointing and restore ops.
  struct LevelEstimatorState {
    bool operator==(const LevelEstimatorState& s) const;
    inline bool operator!=(const LevelEstimatorState& s) const {
      return !(*this == s);
    }
    struct Ratio {
      float numerator;
      float denominator;
      float GetRatio() const;
    };
    // TODO(crbug.com/webrtc/7494): Remove time_to_full_buffer_ms if redundant.
    int time_to_full_buffer_ms;
    Ratio level_dbfs;
    SaturationProtectorState saturation_protector;
  };
  static_assert(std::is_trivially_copyable<LevelEstimatorState>::value, "");

  void ResetLevelEstimatorState(LevelEstimatorState& state) const;

  void DumpDebugData() const;

  ApmDataDumper* const apm_data_dumper_;

  const AudioProcessing::Config::GainController2::LevelEstimator
      level_estimator_type_;
  const int adjacent_speech_frames_threshold_;
  const float initial_saturation_margin_db_;
  const float extra_saturation_margin_db_;
  LevelEstimatorState preliminary_state_;
  LevelEstimatorState reliable_state_;
  float level_dbfs_;
  int num_adjacent_speech_frames_;
};

} // namespace webrtc

#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
@@ -0,0 +1,65 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.h"

#include <cmath>
#include <vector>

#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/include/audio_frame_view.h"

namespace webrtc {

AdaptiveModeLevelEstimatorAgc::AdaptiveModeLevelEstimatorAgc(
    ApmDataDumper* apm_data_dumper)
    : level_estimator_(apm_data_dumper) {
  set_target_level_dbfs(kDefaultAgc2LevelHeadroomDbfs);
}

// |audio| must be mono; in a multi-channel stream, provide the first (usually
// left) channel.
void AdaptiveModeLevelEstimatorAgc::Process(const int16_t* audio,
                                            size_t length,
                                            int sample_rate_hz) {
  std::vector<float> float_audio_frame(audio, audio + length);
  const float* const first_channel = &float_audio_frame[0];
  AudioFrameView<const float> frame_view(&first_channel, 1 /* num channels */,
                                         length);
  const auto vad_prob = agc2_vad_.AnalyzeFrame(frame_view);
  latest_voice_probability_ = vad_prob.speech_probability;
  if (latest_voice_probability_ > kVadConfidenceThreshold) {
    time_in_ms_since_last_estimate_ += kFrameDurationMs;
  }
  level_estimator_.Update(vad_prob);
}

// Retrieves the difference between the target RMS level and the current
// signal RMS level in dB. Returns true if an update is available and false
// otherwise, in which case |error| should be ignored and no action taken.
bool AdaptiveModeLevelEstimatorAgc::GetRmsErrorDb(int* error) {
  if (time_in_ms_since_last_estimate_ <= kTimeUntilConfidentMs) {
    return false;
  }
  *error =
      std::floor(target_level_dbfs() - level_estimator_.level_dbfs() + 0.5f);
  time_in_ms_since_last_estimate_ = 0;
  return true;
}

void AdaptiveModeLevelEstimatorAgc::Reset() {
  level_estimator_.Reset();
}

float AdaptiveModeLevelEstimatorAgc::voice_probability() const {
  return latest_voice_probability_;
}

} // namespace webrtc
@@ -0,0 +1,51 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_

#include <stddef.h>
#include <stdint.h>

#include "modules/audio_processing/agc/agc.h"
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/agc2/vad_with_level.h"

namespace webrtc {
class AdaptiveModeLevelEstimatorAgc : public Agc {
 public:
  explicit AdaptiveModeLevelEstimatorAgc(ApmDataDumper* apm_data_dumper);

  // |audio| must be mono; in a multi-channel stream, provide the first (usually
  // left) channel.
  void Process(const int16_t* audio,
               size_t length,
               int sample_rate_hz) override;

  // Retrieves the difference between the target RMS level and the current
  // signal RMS level in dB. Returns true if an update is available and false
  // otherwise, in which case |error| should be ignored and no action taken.
  bool GetRmsErrorDb(int* error) override;
  void Reset() override;

  float voice_probability() const override;

 private:
  static constexpr int kTimeUntilConfidentMs = 700;
  static constexpr int kDefaultAgc2LevelHeadroomDbfs = -1;
  int32_t time_in_ms_since_last_estimate_ = 0;
  AdaptiveModeLevelEstimator level_estimator_;
  VadLevelAnalyzer agc2_vad_;
  float latest_voice_probability_ = 0.f;
};
} // namespace webrtc

#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_
webrtc/modules/audio_processing/agc2/agc2_common.h (new file, 86 lines)
@@ -0,0 +1,86 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_

#include <stddef.h>

namespace webrtc {

constexpr float kMinFloatS16Value = -32768.f;
constexpr float kMaxFloatS16Value = 32767.f;
constexpr float kMaxAbsFloatS16Value = 32768.0f;

constexpr size_t kFrameDurationMs = 10;
constexpr size_t kSubFramesInFrame = 20;
constexpr size_t kMaximalNumberOfSamplesPerChannel = 480;

constexpr float kAttackFilterConstant = 0.f;

// Adaptive digital gain applier settings below.
constexpr float kHeadroomDbfs = 1.f;
constexpr float kMaxGainDb = 30.f;
constexpr float kInitialAdaptiveDigitalGainDb = 8.f;
// At what limiter levels should we start decreasing the adaptive digital gain.
constexpr float kLimiterThresholdForAgcGainDbfs = -kHeadroomDbfs;

// This is the threshold for speech. Speech frames are used for updating the
// speech level, measuring the amount of speech, and decide when to allow target
// gain reduction.
constexpr float kVadConfidenceThreshold = 0.9f;

// The amount of 'memory' of the Level Estimator. Decides leak factors.
constexpr size_t kFullBufferSizeMs = 1200;
constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs;

constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;

// Robust VAD probability and speech decisions.
constexpr float kDefaultSmoothedVadProbabilityAttack = 1.f;
constexpr int kDefaultLevelEstimatorAdjacentSpeechFramesThreshold = 1;

// Saturation Protector settings.
constexpr float kDefaultInitialSaturationMarginDb = 20.f;
constexpr float kDefaultExtraSaturationMarginDb = 2.f;

constexpr size_t kPeakEnveloperSuperFrameLengthMs = 400;
static_assert(kFullBufferSizeMs % kPeakEnveloperSuperFrameLengthMs == 0,
              "Full buffer size should be a multiple of super frame length for "
              "optimal Saturation Protector performance.");

constexpr size_t kPeakEnveloperBufferSize =
    kFullBufferSizeMs / kPeakEnveloperSuperFrameLengthMs + 1;

// This value is 10 ** (-1/20 * frame_size_ms / satproc_attack_ms),
// where satproc_attack_ms is 5000.
constexpr float kSaturationProtectorAttackConstant = 0.9988493699365052f;

// This value is 10 ** (-1/20 * frame_size_ms / satproc_decay_ms),
// where satproc_decay_ms is 1000.
constexpr float kSaturationProtectorDecayConstant = 0.9997697679981565f;

// This is computed from kDecayMs by
// 10 ** (-1/20 * subframe_duration / kDecayMs).
// |subframe_duration| is |kFrameDurationMs / kSubFramesInFrame|.
// kDecayMs is defined in agc2_testing_common.h
constexpr float kDecayFilterConstant = 0.9998848773724686f;

// Number of interpolation points for each region of the limiter.
// These values have been tuned to limit the interpolated gain curve error given
// the limiter parameters and allowing a maximum error of +/- 32768^-1.
constexpr size_t kInterpolatedGainCurveKneePoints = 22;
constexpr size_t kInterpolatedGainCurveBeyondKneePoints = 10;
constexpr size_t kInterpolatedGainCurveTotalPoints =
    kInterpolatedGainCurveKneePoints + kInterpolatedGainCurveBeyondKneePoints;

} // namespace webrtc

#endif // MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
webrtc/modules/audio_processing/agc2/agc2_testing_common.cc (new file, 33 lines)
@@ -0,0 +1,33 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/agc2_testing_common.h"

#include "rtc_base/checks.h"

namespace webrtc {

namespace test {

std::vector<double> LinSpace(const double l,
                             const double r,
                             size_t num_points) {
  RTC_CHECK(num_points >= 2);
  std::vector<double> points(num_points);
  const double step = (r - l) / (num_points - 1.0);
  points[0] = l;
  for (size_t i = 1; i < num_points - 1; i++) {
    points[i] = static_cast<double>(l) + i * step;
  }
  points[num_points - 1] = r;
  return points;
}
} // namespace test
} // namespace webrtc
webrtc/modules/audio_processing/agc2/agc2_testing_common.h (new file, 78 lines)
@@ -0,0 +1,78 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_

#include <math.h>

#include <limits>
#include <vector>

#include "rtc_base/checks.h"

namespace webrtc {

namespace test {

// Level Estimator test parameters.
constexpr float kDecayMs = 500.f;

// Limiter parameters.
constexpr float kLimiterMaxInputLevelDbFs = 1.f;
constexpr float kLimiterKneeSmoothnessDb = 1.f;
constexpr float kLimiterCompressionRatio = 5.f;
constexpr float kPi = 3.1415926536f;

std::vector<double> LinSpace(const double l, const double r, size_t num_points);

class SineGenerator {
 public:
  SineGenerator(float frequency, int rate)
      : frequency_(frequency), rate_(rate) {}
  float operator()() {
    x_radians_ += frequency_ / rate_ * 2 * kPi;
    if (x_radians_ > 2 * kPi) {
      x_radians_ -= 2 * kPi;
    }
    return 1000.f * sinf(x_radians_);
  }

 private:
  float frequency_;
  int rate_;
  float x_radians_ = 0.f;
};

class PulseGenerator {
 public:
  PulseGenerator(float frequency, int rate)
      : samples_period_(
            static_cast<int>(static_cast<float>(rate) / frequency)) {
    RTC_DCHECK_GT(rate, frequency);
  }
  float operator()() {
    sample_counter_++;
    if (sample_counter_ >= samples_period_) {
      sample_counter_ -= samples_period_;
    }
    return static_cast<float>(
        sample_counter_ == 0 ? std::numeric_limits<int16_t>::max() : 10.f);
  }

 private:
  int samples_period_;
  int sample_counter_ = 0;
};

} // namespace test
} // namespace webrtc

#endif // MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
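As an illustration (not part of the commit), the test generators above are plain functors, so a 10 ms test frame can be filled by calling them once per sample:

// Illustration only: filling a 10 ms frame with the SineGenerator above.
#include <vector>

#include "modules/audio_processing/agc2/agc2_testing_common.h"

int main() {
  webrtc::test::SineGenerator sine(/*frequency=*/440.f, /*rate=*/48000);
  std::vector<float> frame(480);  // 10 ms at 48 kHz.
  for (float& sample : frame) {
    sample = sine();  // 440 Hz tone with amplitude 1000, per operator() above.
  }
  return 0;
}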
webrtc/modules/audio_processing/agc2/biquad_filter.cc (new file, 36 lines)
@@ -0,0 +1,36 @@
/*
 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/biquad_filter.h"

#include <stddef.h>

namespace webrtc {

// Transposed direct form I implementation of a bi-quad filter applied to an
// input signal |x| to produce an output signal |y|.
void BiQuadFilter::Process(rtc::ArrayView<const float> x,
                           rtc::ArrayView<float> y) {
  for (size_t k = 0; k < x.size(); ++k) {
    // Use temporary variable for x[k] to allow in-place function call
    // (that x and y refer to the same array).
    const float tmp = x[k];
    y[k] = coefficients_.b[0] * tmp + coefficients_.b[1] * biquad_state_.b[0] +
           coefficients_.b[2] * biquad_state_.b[1] -
           coefficients_.a[0] * biquad_state_.a[0] -
           coefficients_.a[1] * biquad_state_.a[1];
    biquad_state_.b[1] = biquad_state_.b[0];
    biquad_state_.b[0] = tmp;
    biquad_state_.a[1] = biquad_state_.a[0];
    biquad_state_.a[0] = y[k];
  }
}

} // namespace webrtc
webrtc/modules/audio_processing/agc2/biquad_filter.h (new file, 66 lines)
@@ -0,0 +1,66 @@
/*
 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_

#include <algorithm>

#include "api/array_view.h"
#include "rtc_base/arraysize.h"
#include "rtc_base/constructor_magic.h"

namespace webrtc {

class BiQuadFilter {
 public:
  // Normalized filter coefficients.
  //        b_0 + b_1 • z^(-1) + b_2 • z^(-2)
  // H(z) = ---------------------------------
  //          1 + a_1 • z^(-1) + a_2 • z^(-2)
  struct BiQuadCoefficients {
    float b[3];
    float a[2];
  };

  BiQuadFilter() = default;

  void Initialize(const BiQuadCoefficients& coefficients) {
    coefficients_ = coefficients;
  }

  void Reset() { biquad_state_.Reset(); }

  // Produces a filtered output y of the input x. Both x and y need to
  // have the same length. In-place modification is allowed.
  void Process(rtc::ArrayView<const float> x, rtc::ArrayView<float> y);

 private:
  struct BiQuadState {
    BiQuadState() { Reset(); }

    void Reset() {
      std::fill(b, b + arraysize(b), 0.f);
      std::fill(a, a + arraysize(a), 0.f);
    }

    float b[2];
    float a[2];
  };

  BiQuadState biquad_state_;
  BiQuadCoefficients coefficients_;

  RTC_DISALLOW_COPY_AND_ASSIGN(BiQuadFilter);
};

} // namespace webrtc

#endif // MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
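A minimal usage sketch of the filter above (illustration only, not part of the commit); the coefficient values are arbitrary placeholders rather than a real filter design:

// Illustration only: initializing and running BiQuadFilter in place.
#include <array>

#include "modules/audio_processing/agc2/biquad_filter.h"

int main() {
  webrtc::BiQuadFilter filter;
  // Arbitrary placeholder coefficients {b0, b1, b2} and {a1, a2}.
  webrtc::BiQuadFilter::BiQuadCoefficients coefficients = {{0.2f, 0.2f, 0.2f},
                                                           {-0.5f, 0.1f}};
  filter.Initialize(coefficients);

  std::array<float, 480> samples{};  // One 10 ms frame at 48 kHz.
  // In-place processing is allowed: x and y may refer to the same buffer.
  filter.Process(samples, samples);
  return 0;
}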
@@ -0,0 +1,229 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/compute_interpolated_gain_curve.h"

#include <algorithm>
#include <cmath>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>

#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/agc2_testing_common.h"
#include "modules/audio_processing/agc2/limiter_db_gain_curve.h"
#include "rtc_base/checks.h"

namespace webrtc {
namespace {

std::pair<double, double> ComputeLinearApproximationParams(
    const LimiterDbGainCurve* limiter,
    const double x) {
  const double m = limiter->GetGainFirstDerivativeLinear(x);
  const double q = limiter->GetGainLinear(x) - m * x;
  return {m, q};
}

double ComputeAreaUnderPiecewiseLinearApproximation(
    const LimiterDbGainCurve* limiter,
    const double x0,
    const double x1) {
  RTC_CHECK_LT(x0, x1);

  // Linear approximation in x0 and x1.
  double m0, q0, m1, q1;
  std::tie(m0, q0) = ComputeLinearApproximationParams(limiter, x0);
  std::tie(m1, q1) = ComputeLinearApproximationParams(limiter, x1);

  // Intersection point between two adjacent linear pieces.
  RTC_CHECK_NE(m1, m0);
  const double x_split = (q0 - q1) / (m1 - m0);
  RTC_CHECK_LT(x0, x_split);
  RTC_CHECK_LT(x_split, x1);

  auto area_under_linear_piece = [](double x_l, double x_r, double m,
                                    double q) {
    return x_r * (m * x_r / 2.0 + q) - x_l * (m * x_l / 2.0 + q);
  };
  return area_under_linear_piece(x0, x_split, m0, q0) +
         area_under_linear_piece(x_split, x1, m1, q1);
}

// Computes the approximation error in the limiter region for a given interval.
// The error is computed as the difference between the areas beneath the limiter
// curve to approximate and its linear under-approximation.
double LimiterUnderApproximationNegativeError(const LimiterDbGainCurve* limiter,
                                              const double x0,
                                              const double x1) {
  const double area_limiter = limiter->GetGainIntegralLinear(x0, x1);
  const double area_interpolated_curve =
      ComputeAreaUnderPiecewiseLinearApproximation(limiter, x0, x1);
  RTC_CHECK_GE(area_limiter, area_interpolated_curve);
  return area_limiter - area_interpolated_curve;
}

// Automatically finds where to sample the beyond-knee region of a limiter using
// a greedy optimization algorithm that iteratively decreases the approximation
// error.
// The solution is sub-optimal because the algorithm is greedy and the points
// are assigned by halving intervals (starting with the whole beyond-knee region
// as a single interval). However, even if sub-optimal, this algorithm works
// well in practice and it is efficiently implemented using priority queues.
std::vector<double> SampleLimiterRegion(const LimiterDbGainCurve* limiter) {
  static_assert(kInterpolatedGainCurveBeyondKneePoints > 2, "");

  struct Interval {
    Interval() = default;  // Ctor required by std::priority_queue.
    Interval(double l, double r, double e) : x0(l), x1(r), error(e) {
      RTC_CHECK(x0 < x1);
    }
    bool operator<(const Interval& other) const { return error < other.error; }

    double x0;
    double x1;
    double error;
  };

  std::priority_queue<Interval, std::vector<Interval>> q;
  q.emplace(limiter->limiter_start_linear(), limiter->max_input_level_linear(),
            LimiterUnderApproximationNegativeError(
                limiter, limiter->limiter_start_linear(),
                limiter->max_input_level_linear()));

  // Iteratively find points by halving the interval with greatest error.
  while (q.size() < kInterpolatedGainCurveBeyondKneePoints) {
    // Get the interval with highest error.
    const auto interval = q.top();
    q.pop();

    // Split |interval| and enqueue.
    double x_split = (interval.x0 + interval.x1) / 2.0;
    q.emplace(interval.x0, x_split,
              LimiterUnderApproximationNegativeError(limiter, interval.x0,
                                                     x_split));  // Left.
    q.emplace(x_split, interval.x1,
              LimiterUnderApproximationNegativeError(limiter, x_split,
                                                     interval.x1));  // Right.
  }

  // Copy x1 values and sort them.
  RTC_CHECK_EQ(q.size(), kInterpolatedGainCurveBeyondKneePoints);
  std::vector<double> samples(kInterpolatedGainCurveBeyondKneePoints);
  for (size_t i = 0; i < kInterpolatedGainCurveBeyondKneePoints; ++i) {
    const auto interval = q.top();
    q.pop();
    samples[i] = interval.x1;
  }
  RTC_CHECK(q.empty());
  std::sort(samples.begin(), samples.end());

  return samples;
}

// Compute the parameters to over-approximate the knee region via linear
|
||||
// interpolation. Over-approximating is saturation-safe since the knee region is
|
||||
// convex.
|
||||
void PrecomputeKneeApproxParams(const LimiterDbGainCurve* limiter,
|
||||
test::InterpolatedParameters* parameters) {
|
||||
static_assert(kInterpolatedGainCurveKneePoints > 2, "");
|
||||
// Get |kInterpolatedGainCurveKneePoints| - 1 equally spaced points.
|
||||
const std::vector<double> points = test::LinSpace(
|
||||
limiter->knee_start_linear(), limiter->limiter_start_linear(),
|
||||
kInterpolatedGainCurveKneePoints - 1);
|
||||
|
||||
// Set the first two points. The second is computed to help with the beginning
|
||||
// of the knee region, which has high curvature.
|
||||
parameters->computed_approximation_params_x[0] = points[0];
|
||||
parameters->computed_approximation_params_x[1] =
|
||||
(points[0] + points[1]) / 2.0;
|
||||
// Copy the remaining points.
|
||||
std::copy(std::begin(points) + 1, std::end(points),
|
||||
std::begin(parameters->computed_approximation_params_x) + 2);
|
||||
|
||||
// Compute (m, q) pairs for each linear piece y = mx + q.
|
||||
for (size_t i = 0; i < kInterpolatedGainCurveKneePoints - 1; ++i) {
|
||||
const double x0 = parameters->computed_approximation_params_x[i];
|
||||
const double x1 = parameters->computed_approximation_params_x[i + 1];
|
||||
const double y0 = limiter->GetGainLinear(x0);
|
||||
const double y1 = limiter->GetGainLinear(x1);
|
||||
RTC_CHECK_NE(x1, x0);
|
||||
parameters->computed_approximation_params_m[i] = (y1 - y0) / (x1 - x0);
|
||||
parameters->computed_approximation_params_q[i] =
|
||||
y0 - parameters->computed_approximation_params_m[i] * x0;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the parameters to under-approximate the beyond-knee region via linear
|
||||
// interpolation and greedy sampling. Under-approximating is saturation-safe
|
||||
// since the beyond-knee region is concave.
|
||||
void PrecomputeBeyondKneeApproxParams(
|
||||
const LimiterDbGainCurve* limiter,
|
||||
test::InterpolatedParameters* parameters) {
|
||||
// Find points on which the linear pieces are tangent to the gain curve.
|
||||
const auto samples = SampleLimiterRegion(limiter);
|
||||
|
||||
// Parametrize each linear piece.
|
||||
double m, q;
|
||||
std::tie(m, q) = ComputeLinearApproximationParams(
|
||||
limiter,
|
||||
parameters
|
||||
->computed_approximation_params_x[kInterpolatedGainCurveKneePoints -
|
||||
1]);
|
||||
parameters
|
||||
->computed_approximation_params_m[kInterpolatedGainCurveKneePoints - 1] =
|
||||
m;
|
||||
parameters
|
||||
->computed_approximation_params_q[kInterpolatedGainCurveKneePoints - 1] =
|
||||
q;
|
||||
for (size_t i = 0; i < samples.size(); ++i) {
|
||||
std::tie(m, q) = ComputeLinearApproximationParams(limiter, samples[i]);
|
||||
parameters
|
||||
->computed_approximation_params_m[i +
|
||||
kInterpolatedGainCurveKneePoints] = m;
|
||||
parameters
|
||||
->computed_approximation_params_q[i +
|
||||
kInterpolatedGainCurveKneePoints] = q;
|
||||
}
|
||||
|
||||
// Find the points of intersection between adjacent linear pieces. They will
|
||||
// be used as boundaries between the pieces.
|
||||
for (size_t i = kInterpolatedGainCurveKneePoints;
|
||||
i < kInterpolatedGainCurveKneePoints +
|
||||
kInterpolatedGainCurveBeyondKneePoints;
|
||||
++i) {
|
||||
RTC_CHECK_NE(parameters->computed_approximation_params_m[i],
|
||||
parameters->computed_approximation_params_m[i - 1]);
|
||||
parameters->computed_approximation_params_x[i] =
|
||||
( // Formula: (q0 - q1) / (m1 - m0).
|
||||
parameters->computed_approximation_params_q[i - 1] -
|
||||
parameters->computed_approximation_params_q[i]) /
|
||||
(parameters->computed_approximation_params_m[i] -
|
||||
parameters->computed_approximation_params_m[i - 1]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace test {
|
||||
|
||||
InterpolatedParameters ComputeInterpolatedGainCurveApproximationParams() {
|
||||
InterpolatedParameters parameters;
|
||||
LimiterDbGainCurve limiter;
|
||||
parameters.computed_approximation_params_x.fill(0.0f);
|
||||
parameters.computed_approximation_params_m.fill(0.0f);
|
||||
parameters.computed_approximation_params_q.fill(0.0f);
|
||||
PrecomputeKneeApproxParams(&limiter, ¶meters);
|
||||
PrecomputeBeyondKneeApproxParams(&limiter, ¶meters);
|
||||
return parameters;
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
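The greedy sampling in SampleLimiterRegion() above generalizes to any concave curve: keep a priority queue of intervals ordered by approximation error and repeatedly halve the worst one. A stand-alone sketch of that idea follows; the square-root curve and the chord-based error measure are illustrative only, not the limiter curve or the area-based error used above.

#include <cmath>
#include <cstdio>
#include <queue>

// Toy error measure: how far the chord sags below a concave curve at the
// interval midpoint. The real code uses the area between curve and tangents.
double ChordError(double x0, double x1) {
  const double mid = 0.5 * (x0 + x1);
  const double chord = 0.5 * (std::sqrt(x0) + std::sqrt(x1));
  return std::sqrt(mid) - chord;
}

int main() {
  struct Interval {
    double x0, x1, error;
    bool operator<(const Interval& other) const { return error < other.error; }
  };
  std::priority_queue<Interval> q;
  q.push({1.0, 100.0, ChordError(1.0, 100.0)});
  constexpr size_t kNumPoints = 8;
  // Iteratively halve the interval with the greatest error.
  while (q.size() < kNumPoints) {
    const Interval worst = q.top();
    q.pop();
    const double mid = 0.5 * (worst.x0 + worst.x1);
    q.push({worst.x0, mid, ChordError(worst.x0, mid)});
    q.push({mid, worst.x1, ChordError(mid, worst.x1)});
  }
  while (!q.empty()) {
    std::printf("sample at x1 = %f\n", q.top().x1);
    q.pop();
  }
  return 0;
}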
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
namespace test {
|
||||
|
||||
// Parameters for interpolated gain curve using under-approximation to
|
||||
// avoid saturation.
|
||||
//
|
||||
// The saturation gain is defined in order to let hard-clipping occur for
|
||||
// those samples having a level that falls in the saturation region. It is an
|
||||
// upper bound of the actual gain to apply - i.e., that returned by the
|
||||
// limiter.
|
||||
|
||||
// Knee and beyond-knee regions approximation parameters.
|
||||
// The gain curve is approximated as a piece-wise linear function.
|
||||
// |approx_params_x_| are the boundaries between adjacent linear pieces,
|
||||
// |approx_params_m_| and |approx_params_q_| are the slope and the y-intercept
|
||||
// values of each piece.
|
||||
struct InterpolatedParameters {
|
||||
std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
computed_approximation_params_x;
|
||||
std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
computed_approximation_params_m;
|
||||
std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
computed_approximation_params_q;
|
||||
};
|
||||
|
||||
InterpolatedParameters ComputeInterpolatedGainCurveApproximationParams();
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_
|
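A brief sketch of how the test-only helper above could be used to regenerate the piece-wise linear tables (hypothetical caller, not part of this commit):

#include "modules/audio_processing/agc2/compute_interpolated_gain_curve.h"

void RegenerateTables() {
  const webrtc::test::InterpolatedParameters p =
      webrtc::test::ComputeInterpolatedGainCurveApproximationParams();
  // Gain at the start of the first knee piece: y = m * x + q.
  const float gain_at_knee_start = p.computed_approximation_params_m[0] *
                                       p.computed_approximation_params_x[0] +
                                   p.computed_approximation_params_q[0];
  (void)gain_at_knee_start;
}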
99
webrtc/modules/audio_processing/agc2/down_sampler.cc
Normal file
@ -0,0 +1,99 @@
|
||||
/*
|
||||
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/down_sampler.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "modules/audio_processing/agc2/biquad_filter.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
constexpr int kChunkSizeMs = 10;
|
||||
constexpr int kSampleRate8kHz = 8000;
|
||||
constexpr int kSampleRate16kHz = 16000;
|
||||
constexpr int kSampleRate32kHz = 32000;
|
||||
constexpr int kSampleRate48kHz = 48000;
|
||||
|
||||
// Band-limiting filter coefficients, chosen so that only the
|
||||
// first 40 bins of the spectrum of the downsampled signal
|
||||
// are used.
|
||||
// [B,A] = butter(2,(41/64*4000)/8000)
|
||||
const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_16kHz = {
|
||||
{0.1455f, 0.2911f, 0.1455f},
|
||||
{-0.6698f, 0.2520f}};
|
||||
|
||||
// [B,A] = butter(2,(41/64*4000)/16000)
|
||||
const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_32kHz = {
|
||||
{0.0462f, 0.0924f, 0.0462f},
|
||||
{-1.3066f, 0.4915f}};
|
||||
|
||||
// [B,A] = butter(2,(41/64*4000)/24000)
|
||||
const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_48kHz = {
|
||||
{0.0226f, 0.0452f, 0.0226f},
|
||||
{-1.5320f, 0.6224f}};
|
||||
|
||||
} // namespace
|
||||
|
||||
DownSampler::DownSampler(ApmDataDumper* data_dumper)
|
||||
: data_dumper_(data_dumper) {
|
||||
Initialize(48000);
|
||||
}
|
||||
void DownSampler::Initialize(int sample_rate_hz) {
|
||||
RTC_DCHECK(
|
||||
sample_rate_hz == kSampleRate8kHz || sample_rate_hz == kSampleRate16kHz ||
|
||||
sample_rate_hz == kSampleRate32kHz || sample_rate_hz == kSampleRate48kHz);
|
||||
|
||||
sample_rate_hz_ = sample_rate_hz;
|
||||
down_sampling_factor_ = rtc::CheckedDivExact(sample_rate_hz_, 8000);
|
||||
|
||||
// Note that the down-sampling filter is not used if the sample rate is 8
|
||||
// kHz.
|
||||
if (sample_rate_hz_ == kSampleRate16kHz) {
|
||||
low_pass_filter_.Initialize(kLowPassFilterCoefficients_16kHz);
|
||||
} else if (sample_rate_hz_ == kSampleRate32kHz) {
|
||||
low_pass_filter_.Initialize(kLowPassFilterCoefficients_32kHz);
|
||||
} else if (sample_rate_hz_ == kSampleRate48kHz) {
|
||||
low_pass_filter_.Initialize(kLowPassFilterCoefficients_48kHz);
|
||||
}
|
||||
}
|
||||
|
||||
void DownSampler::DownSample(rtc::ArrayView<const float> in,
|
||||
rtc::ArrayView<float> out) {
|
||||
data_dumper_->DumpWav("lc_down_sampler_input", in, sample_rate_hz_, 1);
|
||||
RTC_DCHECK_EQ(sample_rate_hz_ * kChunkSizeMs / 1000, in.size());
|
||||
RTC_DCHECK_EQ(kSampleRate8kHz * kChunkSizeMs / 1000, out.size());
|
||||
const size_t kMaxNumFrames = kSampleRate48kHz * kChunkSizeMs / 1000;
|
||||
float x[kMaxNumFrames];
|
||||
|
||||
// Band-limit the signal to 4 kHz.
|
||||
if (sample_rate_hz_ != kSampleRate8kHz) {
|
||||
low_pass_filter_.Process(in, rtc::ArrayView<float>(x, in.size()));
|
||||
|
||||
// Downsample the signal.
|
||||
size_t k = 0;
|
||||
for (size_t j = 0; j < out.size(); ++j) {
|
||||
RTC_DCHECK_GT(kMaxNumFrames, k);
|
||||
out[j] = x[k];
|
||||
k += down_sampling_factor_;
|
||||
}
|
||||
} else {
|
||||
std::copy(in.data(), in.data() + in.size(), out.data());
|
||||
}
|
||||
|
||||
data_dumper_->DumpWav("lc_down_sampler_output", out, kSampleRate8kHz, 1);
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
42
webrtc/modules/audio_processing/agc2/down_sampler.h
Normal file
@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/biquad_filter.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
class ApmDataDumper;
|
||||
|
||||
class DownSampler {
|
||||
public:
|
||||
explicit DownSampler(ApmDataDumper* data_dumper);
|
||||
|
||||
DownSampler() = delete;
|
||||
DownSampler(const DownSampler&) = delete;
|
||||
DownSampler& operator=(const DownSampler&) = delete;
|
||||
|
||||
void Initialize(int sample_rate_hz);
|
||||
|
||||
void DownSample(rtc::ArrayView<const float> in, rtc::ArrayView<float> out);
|
||||
|
||||
private:
|
||||
ApmDataDumper* data_dumper_;
|
||||
int sample_rate_hz_;
|
||||
int down_sampling_factor_;
|
||||
BiQuadFilter low_pass_filter_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_
|
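A minimal usage sketch of the DownSampler declared above, assuming 10 ms frames (the buffer sizes follow the RTC_DCHECKs in down_sampler.cc):

#include <array>

#include "modules/audio_processing/agc2/down_sampler.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

void DownSamplerExample() {
  webrtc::ApmDataDumper data_dumper(0);
  webrtc::DownSampler down_sampler(&data_dumper);
  down_sampler.Initialize(48000);

  std::array<float, 480> in{};  // 10 ms at 48 kHz.
  std::array<float, 80> out{};  // 10 ms at 8 kHz.
  down_sampler.DownSample(in, out);
}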
@ -0,0 +1,112 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
constexpr float kInitialFilterStateLevel = 0.f;
|
||||
|
||||
} // namespace
|
||||
|
||||
FixedDigitalLevelEstimator::FixedDigitalLevelEstimator(
|
||||
size_t sample_rate_hz,
|
||||
ApmDataDumper* apm_data_dumper)
|
||||
: apm_data_dumper_(apm_data_dumper),
|
||||
filter_state_level_(kInitialFilterStateLevel) {
|
||||
SetSampleRate(sample_rate_hz);
|
||||
CheckParameterCombination();
|
||||
RTC_DCHECK(apm_data_dumper_);
|
||||
apm_data_dumper_->DumpRaw("agc2_level_estimator_samplerate", sample_rate_hz);
|
||||
}
|
||||
|
||||
void FixedDigitalLevelEstimator::CheckParameterCombination() {
|
||||
RTC_DCHECK_GT(samples_in_frame_, 0);
|
||||
RTC_DCHECK_LE(kSubFramesInFrame, samples_in_frame_);
|
||||
RTC_DCHECK_EQ(samples_in_frame_ % kSubFramesInFrame, 0);
|
||||
RTC_DCHECK_GT(samples_in_sub_frame_, 1);
|
||||
}
|
||||
|
||||
std::array<float, kSubFramesInFrame> FixedDigitalLevelEstimator::ComputeLevel(
|
||||
const AudioFrameView<const float>& float_frame) {
|
||||
RTC_DCHECK_GT(float_frame.num_channels(), 0);
|
||||
RTC_DCHECK_EQ(float_frame.samples_per_channel(), samples_in_frame_);
|
||||
|
||||
// Compute max envelope without smoothing.
|
||||
std::array<float, kSubFramesInFrame> envelope{};
|
||||
for (size_t channel_idx = 0; channel_idx < float_frame.num_channels();
|
||||
++channel_idx) {
|
||||
const auto channel = float_frame.channel(channel_idx);
|
||||
for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
|
||||
for (size_t sample_in_sub_frame = 0;
|
||||
sample_in_sub_frame < samples_in_sub_frame_; ++sample_in_sub_frame) {
|
||||
envelope[sub_frame] =
|
||||
std::max(envelope[sub_frame],
|
||||
std::abs(channel[sub_frame * samples_in_sub_frame_ +
|
||||
sample_in_sub_frame]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure envelope increases happen one step earlier so that the
|
||||
// corresponding *gain decrease* doesn't miss a sudden signal
|
||||
// increase due to interpolation.
|
||||
for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame - 1; ++sub_frame) {
|
||||
if (envelope[sub_frame] < envelope[sub_frame + 1]) {
|
||||
envelope[sub_frame] = envelope[sub_frame + 1];
|
||||
}
|
||||
}
|
||||
|
||||
// Add attack / decay smoothing.
|
||||
for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
|
||||
const float envelope_value = envelope[sub_frame];
|
||||
if (envelope_value > filter_state_level_) {
|
||||
envelope[sub_frame] = envelope_value * (1 - kAttackFilterConstant) +
|
||||
filter_state_level_ * kAttackFilterConstant;
|
||||
} else {
|
||||
envelope[sub_frame] = envelope_value * (1 - kDecayFilterConstant) +
|
||||
filter_state_level_ * kDecayFilterConstant;
|
||||
}
|
||||
filter_state_level_ = envelope[sub_frame];
|
||||
|
||||
// Dump data for debug.
|
||||
RTC_DCHECK(apm_data_dumper_);
|
||||
const auto channel = float_frame.channel(0);
|
||||
apm_data_dumper_->DumpRaw("agc2_level_estimator_samples",
|
||||
samples_in_sub_frame_,
|
||||
&channel[sub_frame * samples_in_sub_frame_]);
|
||||
apm_data_dumper_->DumpRaw("agc2_level_estimator_level",
|
||||
envelope[sub_frame]);
|
||||
}
|
||||
|
||||
return envelope;
|
||||
}
|
||||
|
||||
void FixedDigitalLevelEstimator::SetSampleRate(size_t sample_rate_hz) {
|
||||
samples_in_frame_ = rtc::CheckedDivExact(sample_rate_hz * kFrameDurationMs,
|
||||
static_cast<size_t>(1000));
|
||||
samples_in_sub_frame_ =
|
||||
rtc::CheckedDivExact(samples_in_frame_, kSubFramesInFrame);
|
||||
CheckParameterCombination();
|
||||
}
|
||||
|
||||
void FixedDigitalLevelEstimator::Reset() {
|
||||
filter_state_level_ = kInitialFilterStateLevel;
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
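The attack/decay smoothing above is a one-pole filter whose coefficient depends on whether the level is rising or falling. A stand-alone sketch of that step follows; the constants are placeholders, not the values defined in agc2_common.h:

// Placeholder smoothing constants; the real ones live in agc2_common.h.
constexpr float kAttackConstant = 0.1f;  // Fast tracking when the level rises.
constexpr float kDecayConstant = 0.9f;   // Slow release when the level falls.

float SmoothLevel(float level, float previous_state) {
  const float k = (level > previous_state) ? kAttackConstant : kDecayConstant;
  return level * (1.f - k) + previous_state * k;
}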
@ -0,0 +1,65 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
#include "rtc_base/constructor_magic.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
class ApmDataDumper;
|
||||
// Produces a smooth signal level estimate from an input audio
|
||||
// stream. The estimate smoothing is done through exponential
|
||||
// filtering.
|
||||
class FixedDigitalLevelEstimator {
|
||||
public:
|
||||
// Sample rates are allowed if the number of samples in a frame
|
||||
// (sample_rate_hz * kFrameDurationMs / 1000) is divisible by
|
||||
// kSubFramesInFrame. For kFrameDurationMs=10 and
|
||||
// kSubFramesInFrame=20, this means that sample_rate_hz has to be
|
||||
// divisible by 2000.
|
||||
FixedDigitalLevelEstimator(size_t sample_rate_hz,
|
||||
ApmDataDumper* apm_data_dumper);
|
||||
|
||||
// The input is assumed to be in FloatS16 format. Scaled input will
|
||||
// produce similarly scaled output. A frame with kFrameDurationMs
|
||||
// ms of audio produces a level estimate in the same scale. The
|
||||
// level estimate contains kSubFramesInFrame values.
|
||||
std::array<float, kSubFramesInFrame> ComputeLevel(
|
||||
const AudioFrameView<const float>& float_frame);
|
||||
|
||||
// Rate may be changed at any time (but not concurrently) from the
|
||||
// value passed to the constructor. The class is not thread safe.
|
||||
void SetSampleRate(size_t sample_rate_hz);
|
||||
|
||||
// Resets the level estimator internal state.
|
||||
void Reset();
|
||||
|
||||
float LastAudioLevel() const { return filter_state_level_; }
|
||||
|
||||
private:
|
||||
void CheckParameterCombination();
|
||||
|
||||
ApmDataDumper* const apm_data_dumper_ = nullptr;
|
||||
float filter_state_level_;
|
||||
size_t samples_in_frame_;
|
||||
size_t samples_in_sub_frame_;
|
||||
|
||||
RTC_DISALLOW_COPY_AND_ASSIGN(FixedDigitalLevelEstimator);
|
||||
};
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_
|
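A minimal usage sketch of the level estimator declared above; the AudioFrameView construction assumes one mono channel of 10 ms FloatS16 audio:

#include <array>

#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

void LevelEstimatorExample() {
  webrtc::ApmDataDumper data_dumper(0);
  webrtc::FixedDigitalLevelEstimator estimator(48000, &data_dumper);

  std::array<float, 480> mono{};  // 10 ms of FloatS16 samples at 48 kHz.
  float* channels[] = {mono.data()};
  webrtc::AudioFrameView<const float> frame(channels, 1, mono.size());
  const auto level = estimator.ComputeLevel(frame);  // kSubFramesInFrame values.
  (void)level;
}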
101
webrtc/modules/audio_processing/agc2/fixed_gain_controller.cc
Normal file
@ -0,0 +1,101 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/fixed_gain_controller.h"
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "common_audio/include/audio_util.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
// Returns true when the gain factor is so close to 1 that it would
|
||||
// not affect int16 samples.
|
||||
bool CloseToOne(float gain_factor) {
|
||||
return 1.f - 1.f / kMaxFloatS16Value <= gain_factor &&
|
||||
gain_factor <= 1.f + 1.f / kMaxFloatS16Value;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
FixedGainController::FixedGainController(ApmDataDumper* apm_data_dumper)
|
||||
: FixedGainController(apm_data_dumper, "Agc2") {}
|
||||
|
||||
FixedGainController::FixedGainController(ApmDataDumper* apm_data_dumper,
|
||||
std::string histogram_name_prefix)
|
||||
: apm_data_dumper_(apm_data_dumper),
|
||||
limiter_(48000, apm_data_dumper_, histogram_name_prefix) {
|
||||
// Do update histograms.xml when adding name prefixes.
|
||||
RTC_DCHECK(histogram_name_prefix == "" || histogram_name_prefix == "Test" ||
|
||||
histogram_name_prefix == "AudioMixer" ||
|
||||
histogram_name_prefix == "Agc2");
|
||||
}
|
||||
|
||||
void FixedGainController::SetGain(float gain_to_apply_db) {
|
||||
// Changes in gain_to_apply_ cause discontinuities. We assume
|
||||
// gain_to_apply_ is set at the beginning of the call. If it is
|
||||
// frequently changed, we should add interpolation between the
|
||||
// values.
|
||||
// The gain is expected to be in the [-50, 50] dB range.
|
||||
RTC_DCHECK_LE(-50.f, gain_to_apply_db);
|
||||
RTC_DCHECK_LE(gain_to_apply_db, 50.f);
|
||||
const float previous_applied_gained = gain_to_apply_;
|
||||
gain_to_apply_ = DbToRatio(gain_to_apply_db);
|
||||
RTC_DCHECK_LT(0.f, gain_to_apply_);
|
||||
RTC_DLOG(LS_INFO) << "Gain to apply: " << gain_to_apply_db << " db.";
|
||||
// Reset the gain curve applier to quickly react to abrupt level changes
|
||||
// caused by large changes of the applied gain.
|
||||
if (previous_applied_gained != gain_to_apply_) {
|
||||
limiter_.Reset();
|
||||
}
|
||||
}
|
||||
|
||||
void FixedGainController::SetSampleRate(size_t sample_rate_hz) {
|
||||
limiter_.SetSampleRate(sample_rate_hz);
|
||||
}
|
||||
|
||||
void FixedGainController::Process(AudioFrameView<float> signal) {
|
||||
// Apply fixed digital gain. One of the
|
||||
// planned usages of the FGC is to only use the limiter. In that
|
||||
// case, the gain would be 1.0. Not doing the multiplications speeds
|
||||
// it up considerably. Hence the check.
|
||||
if (!CloseToOne(gain_to_apply_)) {
|
||||
for (size_t k = 0; k < signal.num_channels(); ++k) {
|
||||
rtc::ArrayView<float> channel_view = signal.channel(k);
|
||||
for (auto& sample : channel_view) {
|
||||
sample *= gain_to_apply_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use the limiter.
|
||||
limiter_.Process(signal);
|
||||
|
||||
// Dump data for debug.
|
||||
const auto channel_view = signal.channel(0);
|
||||
apm_data_dumper_->DumpRaw("agc2_fixed_digital_gain_curve_applier",
|
||||
channel_view.size(), channel_view.data());
|
||||
// Hard-clipping.
|
||||
for (size_t k = 0; k < signal.num_channels(); ++k) {
|
||||
rtc::ArrayView<float> channel_view = signal.channel(k);
|
||||
for (auto& sample : channel_view) {
|
||||
sample = rtc::SafeClamp(sample, kMinFloatS16Value, kMaxFloatS16Value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float FixedGainController::LastAudioLevel() const {
|
||||
return limiter_.LastAudioLevel();
|
||||
}
|
||||
} // namespace webrtc
|
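SetGain() above converts the fixed gain from dB to a linear factor via DbToRatio(). Assuming the usual amplitude convention, that conversion is equivalent to the sketch below (for example, 6 dB maps to a factor of roughly 1.995):

#include <cmath>

// Equivalent of the DbToRatio() helper (defined in agc2_common.h), assuming
// the standard 20 * log10 amplitude convention.
float DbToLinearFactor(float gain_db) {
  return std::pow(10.f, gain_db / 20.f);
}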
102
webrtc/modules/audio_processing/agc2/gain_applier.cc
Normal file
@ -0,0 +1,102 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/gain_applier.h"
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
// Returns true when the gain factor is so close to 1 that it would
|
||||
// not affect int16 samples.
|
||||
bool GainCloseToOne(float gain_factor) {
|
||||
return 1.f - 1.f / kMaxFloatS16Value <= gain_factor &&
|
||||
gain_factor <= 1.f + 1.f / kMaxFloatS16Value;
|
||||
}
|
||||
|
||||
void ClipSignal(AudioFrameView<float> signal) {
|
||||
for (size_t k = 0; k < signal.num_channels(); ++k) {
|
||||
rtc::ArrayView<float> channel_view = signal.channel(k);
|
||||
for (auto& sample : channel_view) {
|
||||
sample = rtc::SafeClamp(sample, kMinFloatS16Value, kMaxFloatS16Value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ApplyGainWithRamping(float last_gain_linear,
|
||||
float gain_at_end_of_frame_linear,
|
||||
float inverse_samples_per_channel,
|
||||
AudioFrameView<float> float_frame) {
|
||||
// Do not modify the signal when the gain is constant and close to one.
|
||||
if (last_gain_linear == gain_at_end_of_frame_linear &&
|
||||
GainCloseToOne(gain_at_end_of_frame_linear)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Gain is constant and different from 1.
|
||||
if (last_gain_linear == gain_at_end_of_frame_linear) {
|
||||
for (size_t k = 0; k < float_frame.num_channels(); ++k) {
|
||||
rtc::ArrayView<float> channel_view = float_frame.channel(k);
|
||||
for (auto& sample : channel_view) {
|
||||
sample *= gain_at_end_of_frame_linear;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// The gain changes. We have to change it slowly to avoid discontinuities.
|
||||
const float increment = (gain_at_end_of_frame_linear - last_gain_linear) *
|
||||
inverse_samples_per_channel;
|
||||
float gain = last_gain_linear;
|
||||
for (size_t i = 0; i < float_frame.samples_per_channel(); ++i) {
|
||||
for (size_t ch = 0; ch < float_frame.num_channels(); ++ch) {
|
||||
float_frame.channel(ch)[i] *= gain;
|
||||
}
|
||||
gain += increment;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GainApplier::GainApplier(bool hard_clip_samples, float initial_gain_factor)
|
||||
: hard_clip_samples_(hard_clip_samples),
|
||||
last_gain_factor_(initial_gain_factor),
|
||||
current_gain_factor_(initial_gain_factor) {}
|
||||
|
||||
void GainApplier::ApplyGain(AudioFrameView<float> signal) {
|
||||
if (static_cast<int>(signal.samples_per_channel()) != samples_per_channel_) {
|
||||
Initialize(signal.samples_per_channel());
|
||||
}
|
||||
|
||||
ApplyGainWithRamping(last_gain_factor_, current_gain_factor_,
|
||||
inverse_samples_per_channel_, signal);
|
||||
|
||||
last_gain_factor_ = current_gain_factor_;
|
||||
|
||||
if (hard_clip_samples_) {
|
||||
ClipSignal(signal);
|
||||
}
|
||||
}
|
||||
|
||||
void GainApplier::SetGainFactor(float gain_factor) {
|
||||
RTC_DCHECK_GT(gain_factor, 0.f);
|
||||
current_gain_factor_ = gain_factor;
|
||||
}
|
||||
|
||||
void GainApplier::Initialize(size_t samples_per_channel) {
|
||||
RTC_DCHECK_GT(samples_per_channel, 0);
|
||||
samples_per_channel_ = static_cast<int>(samples_per_channel);
|
||||
inverse_samples_per_channel_ = 1.f / samples_per_channel_;
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
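The gain ramping above can be illustrated with a small worked example: moving from a gain of 1.0 to 2.0 over a 10-sample channel applies 1.0, 1.1, ..., 1.9 to successive samples, and the target gain is reached at the start of the next frame. A stand-alone sketch:

#include <cstdio>

int main() {
  const float last_gain = 1.f;
  const float target_gain = 2.f;
  const int samples_per_channel = 10;
  const float increment = (target_gain - last_gain) / samples_per_channel;
  float gain = last_gain;
  for (int i = 0; i < samples_per_channel; ++i) {
    std::printf("sample %d scaled by %.1f\n", i, gain);
    gain += increment;
  }
  return 0;
}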
44
webrtc/modules/audio_processing/agc2/gain_applier.h
Normal file
@ -0,0 +1,44 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
class GainApplier {
|
||||
public:
|
||||
GainApplier(bool hard_clip_samples, float initial_gain_factor);
|
||||
|
||||
void ApplyGain(AudioFrameView<float> signal);
|
||||
void SetGainFactor(float gain_factor);
|
||||
float GetGainFactor() const { return current_gain_factor_; }
|
||||
|
||||
private:
|
||||
void Initialize(size_t samples_per_channel);
|
||||
|
||||
// Whether to clip samples after gain is applied. If 'true', result
|
||||
// will fit in FloatS16 range.
|
||||
const bool hard_clip_samples_;
|
||||
float last_gain_factor_;
|
||||
|
||||
// If this value is not equal to 'last_gain_factor_', the gain will be
|
||||
// ramped from 'last_gain_factor_' to this value during the next
|
||||
// 'ApplyGain'.
|
||||
float current_gain_factor_;
|
||||
int samples_per_channel_ = -1;
|
||||
float inverse_samples_per_channel_ = -1.f;
|
||||
};
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
|
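A minimal usage sketch of the GainApplier declared above, ramping one frame toward a 2x gain with hard-clipping enabled:

#include "modules/audio_processing/agc2/gain_applier.h"

void GainApplierExample(webrtc::AudioFrameView<float> frame) {
  webrtc::GainApplier applier(/*hard_clip_samples=*/true,
                              /*initial_gain_factor=*/1.f);
  applier.SetGainFactor(2.f);
  // The gain ramps from 1.f to 2.f within this frame; the result is clamped
  // to the FloatS16 range.
  applier.ApplyGain(frame);
}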
195
webrtc/modules/audio_processing/agc2/interpolated_gain_curve.cc
Normal file
@ -0,0 +1,195 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/interpolated_gain_curve.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
InterpolatedGainCurve::approximation_params_x_;
|
||||
|
||||
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
InterpolatedGainCurve::approximation_params_m_;
|
||||
|
||||
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
InterpolatedGainCurve::approximation_params_q_;
|
||||
|
||||
InterpolatedGainCurve::InterpolatedGainCurve(ApmDataDumper* apm_data_dumper,
|
||||
std::string histogram_name_prefix)
|
||||
: region_logger_("WebRTC.Audio." + histogram_name_prefix +
|
||||
".FixedDigitalGainCurveRegion.Identity",
|
||||
"WebRTC.Audio." + histogram_name_prefix +
|
||||
".FixedDigitalGainCurveRegion.Knee",
|
||||
"WebRTC.Audio." + histogram_name_prefix +
|
||||
".FixedDigitalGainCurveRegion.Limiter",
|
||||
"WebRTC.Audio." + histogram_name_prefix +
|
||||
".FixedDigitalGainCurveRegion.Saturation"),
|
||||
apm_data_dumper_(apm_data_dumper) {}
|
||||
|
||||
InterpolatedGainCurve::~InterpolatedGainCurve() {
|
||||
if (stats_.available) {
|
||||
RTC_DCHECK(apm_data_dumper_);
|
||||
apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_identity",
|
||||
stats_.look_ups_identity_region);
|
||||
apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_knee",
|
||||
stats_.look_ups_knee_region);
|
||||
apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_limiter",
|
||||
stats_.look_ups_limiter_region);
|
||||
apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_saturation",
|
||||
stats_.look_ups_saturation_region);
|
||||
region_logger_.LogRegionStats(stats_);
|
||||
}
|
||||
}
|
||||
|
||||
InterpolatedGainCurve::RegionLogger::RegionLogger(
|
||||
std::string identity_histogram_name,
|
||||
std::string knee_histogram_name,
|
||||
std::string limiter_histogram_name,
|
||||
std::string saturation_histogram_name)
|
||||
: identity_histogram(
|
||||
metrics::HistogramFactoryGetCounts(identity_histogram_name,
|
||||
1,
|
||||
10000,
|
||||
50)),
|
||||
knee_histogram(metrics::HistogramFactoryGetCounts(knee_histogram_name,
|
||||
1,
|
||||
10000,
|
||||
50)),
|
||||
limiter_histogram(
|
||||
metrics::HistogramFactoryGetCounts(limiter_histogram_name,
|
||||
1,
|
||||
10000,
|
||||
50)),
|
||||
saturation_histogram(
|
||||
metrics::HistogramFactoryGetCounts(saturation_histogram_name,
|
||||
1,
|
||||
10000,
|
||||
50)) {}
|
||||
|
||||
InterpolatedGainCurve::RegionLogger::~RegionLogger() = default;
|
||||
|
||||
void InterpolatedGainCurve::RegionLogger::LogRegionStats(
|
||||
const InterpolatedGainCurve::Stats& stats) const {
|
||||
using Region = InterpolatedGainCurve::GainCurveRegion;
|
||||
const int duration_s =
|
||||
stats.region_duration_frames / (1000 / kFrameDurationMs);
|
||||
|
||||
switch (stats.region) {
|
||||
case Region::kIdentity: {
|
||||
if (identity_histogram) {
|
||||
metrics::HistogramAdd(identity_histogram, duration_s);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Region::kKnee: {
|
||||
if (knee_histogram) {
|
||||
metrics::HistogramAdd(knee_histogram, duration_s);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Region::kLimiter: {
|
||||
if (limiter_histogram) {
|
||||
metrics::HistogramAdd(limiter_histogram, duration_s);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Region::kSaturation: {
|
||||
if (saturation_histogram) {
|
||||
metrics::HistogramAdd(saturation_histogram, duration_s);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
RTC_NOTREACHED();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InterpolatedGainCurve::UpdateStats(float input_level) const {
|
||||
stats_.available = true;
|
||||
|
||||
GainCurveRegion region;
|
||||
|
||||
if (input_level < approximation_params_x_[0]) {
|
||||
stats_.look_ups_identity_region++;
|
||||
region = GainCurveRegion::kIdentity;
|
||||
} else if (input_level <
|
||||
approximation_params_x_[kInterpolatedGainCurveKneePoints - 1]) {
|
||||
stats_.look_ups_knee_region++;
|
||||
region = GainCurveRegion::kKnee;
|
||||
} else if (input_level < kMaxInputLevelLinear) {
|
||||
stats_.look_ups_limiter_region++;
|
||||
region = GainCurveRegion::kLimiter;
|
||||
} else {
|
||||
stats_.look_ups_saturation_region++;
|
||||
region = GainCurveRegion::kSaturation;
|
||||
}
|
||||
|
||||
if (region == stats_.region) {
|
||||
++stats_.region_duration_frames;
|
||||
} else {
|
||||
region_logger_.LogRegionStats(stats_);
|
||||
|
||||
stats_.region_duration_frames = 0;
|
||||
stats_.region = region;
|
||||
}
|
||||
}
|
||||
|
||||
// Looks up a gain to apply given a non-negative input level.
|
||||
// The cost of this operation depends on the region in which |input_level|
|
||||
// falls.
|
||||
// For the identity and the saturation regions the cost is O(1).
|
||||
// For the other regions, namely knee and limiter, the cost is
|
||||
// O(2 + log2(|kInterpolatedGainCurveTotalPoints|)), plus O(1) for the
|
||||
// linear interpolation (one product and one sum).
|
||||
float InterpolatedGainCurve::LookUpGainToApply(float input_level) const {
|
||||
UpdateStats(input_level);
|
||||
|
||||
if (input_level <= approximation_params_x_[0]) {
|
||||
// Identity region.
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
if (input_level >= kMaxInputLevelLinear) {
|
||||
// Saturating lower bound. The saturating samples exactly hit the clipping
|
||||
// level. This method has the lowest harmonic distortion, but it
|
||||
// may reduce the amplitude of the non-saturating samples too much.
|
||||
return 32768.f / input_level;
|
||||
}
|
||||
|
||||
// Knee and limiter regions; find the linear piece index. Spelling
|
||||
// out the complete type was the only way to silence both the clang
|
||||
// plugin and the windows compilers.
|
||||
std::array<float, kInterpolatedGainCurveTotalPoints>::const_iterator it =
|
||||
std::lower_bound(approximation_params_x_.begin(),
|
||||
approximation_params_x_.end(), input_level);
|
||||
const size_t index = std::distance(approximation_params_x_.begin(), it) - 1;
|
||||
RTC_DCHECK_LE(0, index);
|
||||
RTC_DCHECK_LT(index, approximation_params_m_.size());
|
||||
RTC_DCHECK_LE(approximation_params_x_[index], input_level);
|
||||
if (index < approximation_params_m_.size() - 1) {
|
||||
RTC_DCHECK_LE(input_level, approximation_params_x_[index + 1]);
|
||||
}
|
||||
|
||||
// Piece-wise linear interpolation.
|
||||
const float gain = approximation_params_m_[index] * input_level +
|
||||
approximation_params_q_[index];
|
||||
RTC_DCHECK_LE(0.f, gain);
|
||||
return gain;
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
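The look-up in LookUpGainToApply() above reduces to a std::lower_bound search followed by one multiply-add. A stand-alone sketch with illustrative three-point tables (not the AGC2 constants); it assumes xs[0] < x <= xs.back(), which the identity/saturation early returns guarantee in the real code:

#include <algorithm>
#include <array>
#include <cstddef>
#include <iterator>

float PiecewiseLinearLookUp(float x) {
  static constexpr std::array<float, 3> xs = {{1.f, 2.f, 3.f}};
  static constexpr std::array<float, 3> ms = {{0.5f, 0.25f, 0.125f}};
  static constexpr std::array<float, 3> qs = {{0.5f, 1.f, 1.375f}};
  // Find the linear piece containing x, then evaluate y = m * x + q.
  const auto it = std::lower_bound(xs.begin(), xs.end(), x);
  const size_t index = std::distance(xs.begin(), it) - 1;
  return ms[index] * x + qs[index];
}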
152
webrtc/modules/audio_processing/agc2/interpolated_gain_curve.h
Normal file
@ -0,0 +1,152 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "rtc_base/constructor_magic.h"
|
||||
#include "rtc_base/gtest_prod_util.h"
|
||||
#include "system_wrappers/include/metrics.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
class ApmDataDumper;
|
||||
|
||||
constexpr float kInputLevelScalingFactor = 32768.0f;
|
||||
|
||||
// Defined as DbfsToLinear(kLimiterMaxInputLevelDbFs)
|
||||
constexpr float kMaxInputLevelLinear = static_cast<float>(36766.300710566735);
|
||||
|
||||
// Interpolated gain curve using under-approximation to avoid saturation.
|
||||
//
|
||||
// The goal of this class is to allow fast look-ups to get an accurate
|
||||
// estimate of the gain to apply given an estimated input level.
|
||||
class InterpolatedGainCurve {
|
||||
public:
|
||||
enum class GainCurveRegion {
|
||||
kIdentity = 0,
|
||||
kKnee = 1,
|
||||
kLimiter = 2,
|
||||
kSaturation = 3
|
||||
};
|
||||
|
||||
struct Stats {
|
||||
// Region in which the output level equals the input one.
|
||||
size_t look_ups_identity_region = 0;
|
||||
// Smoothing between the identity and the limiter regions.
|
||||
size_t look_ups_knee_region = 0;
|
||||
// Limiter region in which the output and input levels are linearly related.
|
||||
size_t look_ups_limiter_region = 0;
|
||||
// Region in which saturation may occur since the input level is beyond the
|
||||
// maximum expected by the limiter.
|
||||
size_t look_ups_saturation_region = 0;
|
||||
// True if stats have been populated.
|
||||
bool available = false;
|
||||
|
||||
// The current region, and for how many frames the level has been
|
||||
// in that region.
|
||||
GainCurveRegion region = GainCurveRegion::kIdentity;
|
||||
int64_t region_duration_frames = 0;
|
||||
};
|
||||
|
||||
InterpolatedGainCurve(ApmDataDumper* apm_data_dumper,
|
||||
std::string histogram_name_prefix);
|
||||
~InterpolatedGainCurve();
|
||||
|
||||
Stats get_stats() const { return stats_; }
|
||||
|
||||
// Given a non-negative input level (linear scale), a scalar factor to apply
|
||||
// to a sub-frame is returned.
|
||||
// Levels above kLimiterMaxInputLevelDbFs will be reduced to 0 dBFS
|
||||
// after applying this gain.
|
||||
float LookUpGainToApply(float input_level) const;
|
||||
|
||||
private:
|
||||
// For comparing 'approximation_params_*_' with ones computed by
|
||||
// ComputeInterpolatedGainCurve.
|
||||
FRIEND_TEST_ALL_PREFIXES(AutomaticGainController2InterpolatedGainCurve,
|
||||
CheckApproximationParams);
|
||||
|
||||
struct RegionLogger {
|
||||
metrics::Histogram* identity_histogram;
|
||||
metrics::Histogram* knee_histogram;
|
||||
metrics::Histogram* limiter_histogram;
|
||||
metrics::Histogram* saturation_histogram;
|
||||
|
||||
RegionLogger(std::string identity_histogram_name,
|
||||
std::string knee_histogram_name,
|
||||
std::string limiter_histogram_name,
|
||||
std::string saturation_histogram_name);
|
||||
|
||||
~RegionLogger();
|
||||
|
||||
void LogRegionStats(const InterpolatedGainCurve::Stats& stats) const;
|
||||
} region_logger_;
|
||||
|
||||
void UpdateStats(float input_level) const;
|
||||
|
||||
ApmDataDumper* const apm_data_dumper_;
|
||||
|
||||
static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
approximation_params_x_ = {
|
||||
{30057.296875, 30148.986328125, 30240.67578125, 30424.052734375,
|
||||
30607.4296875, 30790.806640625, 30974.18359375, 31157.560546875,
|
||||
31340.939453125, 31524.31640625, 31707.693359375, 31891.0703125,
|
||||
32074.447265625, 32257.82421875, 32441.201171875, 32624.580078125,
|
||||
32807.95703125, 32991.33203125, 33174.7109375, 33358.08984375,
|
||||
33541.46484375, 33724.84375, 33819.53515625, 34009.5390625,
|
||||
34200.05859375, 34389.81640625, 34674.48828125, 35054.375,
|
||||
35434.86328125, 35814.81640625, 36195.16796875, 36575.03125}};
|
||||
static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
approximation_params_m_ = {
|
||||
{-3.515235675877192989e-07, -1.050251626111275982e-06,
|
||||
-2.085213736791047268e-06, -3.443004743530764244e-06,
|
||||
-4.773849468620028347e-06, -6.077375928725814447e-06,
|
||||
-7.353257842623861507e-06, -8.601219633419532329e-06,
|
||||
-9.821013009059242904e-06, -1.101243378798244521e-05,
|
||||
-1.217532644659513608e-05, -1.330956911260727793e-05,
|
||||
-1.441507538402220234e-05, -1.549179251014720649e-05,
|
||||
-1.653970684856176376e-05, -1.755882840370759368e-05,
|
||||
-1.854918446042574942e-05, -1.951086778717581183e-05,
|
||||
-2.044398024736437947e-05, -2.1348627342376858e-05,
|
||||
-2.222496914328075945e-05, -2.265374678245279938e-05,
|
||||
-2.242570917587727308e-05, -2.220122041762806475e-05,
|
||||
-2.19802095671184361e-05, -2.176260204578284174e-05,
|
||||
-2.133731686626560986e-05, -2.092481918225530535e-05,
|
||||
-2.052459603874012828e-05, -2.013615448959171772e-05,
|
||||
-1.975903069251216948e-05, -1.939277899509761482e-05}};
|
||||
|
||||
static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
approximation_params_q_ = {
|
||||
{1.010565876960754395, 1.031631827354431152, 1.062929749488830566,
|
||||
1.104239225387573242, 1.144973039627075195, 1.185109615325927734,
|
||||
1.224629044532775879, 1.263512492179870605, 1.301741957664489746,
|
||||
1.339300632476806641, 1.376173257827758789, 1.412345528602600098,
|
||||
1.447803974151611328, 1.482536554336547852, 1.516532182693481445,
|
||||
1.549780607223510742, 1.582272171974182129, 1.613999366760253906,
|
||||
1.644955039024353027, 1.675132393836975098, 1.704526185989379883,
|
||||
1.718986630439758301, 1.711274504661560059, 1.703639745712280273,
|
||||
1.696081161499023438, 1.688597679138183594, 1.673851132392883301,
|
||||
1.659391283988952637, 1.645209431648254395, 1.631297469139099121,
|
||||
1.617647409439086914, 1.604251742362976074}};
|
||||
|
||||
// Stats.
|
||||
mutable Stats stats_;
|
||||
|
||||
RTC_DISALLOW_COPY_AND_ASSIGN(InterpolatedGainCurve);
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
|
150
webrtc/modules/audio_processing/agc2/limiter.cc
Normal file
@ -0,0 +1,150 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/limiter.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
// This constant affects the way scaling factors are interpolated for the first
|
||||
// sub-frame of a frame. Only in the case in which the first sub-frame has an
|
||||
// estimated level which is greater than that of the previously analyzed
|
||||
// sub-frame, linear interpolation is replaced with a power function which
|
||||
// reduces the chances of over-shooting (and hence saturation), though it reduces
|
||||
// the fixed gain effectiveness.
|
||||
constexpr float kAttackFirstSubframeInterpolationPower = 8.f;
|
||||
|
||||
void InterpolateFirstSubframe(float last_factor,
|
||||
float current_factor,
|
||||
rtc::ArrayView<float> subframe) {
|
||||
const auto n = subframe.size();
|
||||
constexpr auto p = kAttackFirstSubframeInterpolationPower;
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
subframe[i] = std::pow(1.f - i / n, p) * (last_factor - current_factor) +
|
||||
current_factor;
|
||||
}
|
||||
}
|
||||
|
||||
void ComputePerSampleSubframeFactors(
|
||||
const std::array<float, kSubFramesInFrame + 1>& scaling_factors,
|
||||
size_t samples_per_channel,
|
||||
rtc::ArrayView<float> per_sample_scaling_factors) {
|
||||
const size_t num_subframes = scaling_factors.size() - 1;
|
||||
const size_t subframe_size =
|
||||
rtc::CheckedDivExact(samples_per_channel, num_subframes);
|
||||
|
||||
// Handle first sub-frame differently in case of attack.
|
||||
const bool is_attack = scaling_factors[0] > scaling_factors[1];
|
||||
if (is_attack) {
|
||||
InterpolateFirstSubframe(
|
||||
scaling_factors[0], scaling_factors[1],
|
||||
rtc::ArrayView<float>(
|
||||
per_sample_scaling_factors.subview(0, subframe_size)));
|
||||
}
|
||||
|
||||
for (size_t i = is_attack ? 1 : 0; i < num_subframes; ++i) {
|
||||
const size_t subframe_start = i * subframe_size;
|
||||
const float scaling_start = scaling_factors[i];
|
||||
const float scaling_end = scaling_factors[i + 1];
|
||||
const float scaling_diff = (scaling_end - scaling_start) / subframe_size;
|
||||
for (size_t j = 0; j < subframe_size; ++j) {
|
||||
per_sample_scaling_factors[subframe_start + j] =
|
||||
scaling_start + scaling_diff * j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ScaleSamples(rtc::ArrayView<const float> per_sample_scaling_factors,
|
||||
AudioFrameView<float> signal) {
|
||||
const size_t samples_per_channel = signal.samples_per_channel();
|
||||
RTC_DCHECK_EQ(samples_per_channel, per_sample_scaling_factors.size());
|
||||
for (size_t i = 0; i < signal.num_channels(); ++i) {
|
||||
auto channel = signal.channel(i);
|
||||
for (size_t j = 0; j < samples_per_channel; ++j) {
|
||||
channel[j] = rtc::SafeClamp(channel[j] * per_sample_scaling_factors[j],
|
||||
kMinFloatS16Value, kMaxFloatS16Value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckLimiterSampleRate(size_t sample_rate_hz) {
|
||||
// Check that per_sample_scaling_factors_ is large enough.
|
||||
RTC_DCHECK_LE(sample_rate_hz,
|
||||
kMaximalNumberOfSamplesPerChannel * 1000 / kFrameDurationMs);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Limiter::Limiter(size_t sample_rate_hz,
|
||||
ApmDataDumper* apm_data_dumper,
|
||||
std::string histogram_name)
|
||||
: interp_gain_curve_(apm_data_dumper, histogram_name),
|
||||
level_estimator_(sample_rate_hz, apm_data_dumper),
|
||||
apm_data_dumper_(apm_data_dumper) {
|
||||
CheckLimiterSampleRate(sample_rate_hz);
|
||||
}
|
||||
|
||||
Limiter::~Limiter() = default;
|
||||
|
||||
void Limiter::Process(AudioFrameView<float> signal) {
|
||||
const auto level_estimate = level_estimator_.ComputeLevel(signal);
|
||||
|
||||
RTC_DCHECK_EQ(level_estimate.size() + 1, scaling_factors_.size());
|
||||
scaling_factors_[0] = last_scaling_factor_;
|
||||
std::transform(level_estimate.begin(), level_estimate.end(),
|
||||
scaling_factors_.begin() + 1, [this](float x) {
|
||||
return interp_gain_curve_.LookUpGainToApply(x);
|
||||
});
|
||||
|
||||
const size_t samples_per_channel = signal.samples_per_channel();
|
||||
RTC_DCHECK_LE(samples_per_channel, kMaximalNumberOfSamplesPerChannel);
|
||||
|
||||
auto per_sample_scaling_factors = rtc::ArrayView<float>(
|
||||
&per_sample_scaling_factors_[0], samples_per_channel);
|
||||
ComputePerSampleSubframeFactors(scaling_factors_, samples_per_channel,
|
||||
per_sample_scaling_factors);
|
||||
ScaleSamples(per_sample_scaling_factors, signal);
|
||||
|
||||
last_scaling_factor_ = scaling_factors_.back();
|
||||
|
||||
// Dump data for debug.
|
||||
apm_data_dumper_->DumpRaw("agc2_gain_curve_applier_scaling_factors",
|
||||
samples_per_channel,
|
||||
per_sample_scaling_factors_.data());
|
||||
}
|
||||
|
||||
InterpolatedGainCurve::Stats Limiter::GetGainCurveStats() const {
|
||||
return interp_gain_curve_.get_stats();
|
||||
}
|
||||
|
||||
void Limiter::SetSampleRate(size_t sample_rate_hz) {
|
||||
CheckLimiterSampleRate(sample_rate_hz);
|
||||
level_estimator_.SetSampleRate(sample_rate_hz);
|
||||
}
|
||||
|
||||
void Limiter::Reset() {
|
||||
level_estimator_.Reset();
|
||||
}
|
||||
|
||||
float Limiter::LastAudioLevel() const {
|
||||
return level_estimator_.LastAudioLevel();
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
64
webrtc/modules/audio_processing/agc2/limiter.h
Normal file
@ -0,0 +1,64 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
|
||||
#include "modules/audio_processing/agc2/interpolated_gain_curve.h"
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
#include "rtc_base/constructor_magic.h"
|
||||
|
||||
namespace webrtc {
|
||||
class ApmDataDumper;
|
||||
|
||||
class Limiter {
|
||||
public:
|
||||
Limiter(size_t sample_rate_hz,
|
||||
ApmDataDumper* apm_data_dumper,
|
||||
std::string histogram_name_prefix);
|
||||
Limiter(const Limiter& limiter) = delete;
|
||||
Limiter& operator=(const Limiter& limiter) = delete;
|
||||
~Limiter();
|
||||
|
||||
// Applies limiter and hard-clipping to |signal|.
|
||||
void Process(AudioFrameView<float> signal);
|
||||
InterpolatedGainCurve::Stats GetGainCurveStats() const;
|
||||
|
||||
// Supported rates must be
|
||||
// * supported by FixedDigitalLevelEstimator
|
||||
// * below kMaximalNumberOfSamplesPerChannel*1000/kFrameDurationMs
|
||||
// so that samples_per_channel fits in the
|
||||
// per_sample_scaling_factors_ array.
|
||||
void SetSampleRate(size_t sample_rate_hz);
|
||||
|
||||
// Resets the internal state.
|
||||
void Reset();
|
||||
|
||||
float LastAudioLevel() const;
|
||||
|
||||
private:
|
||||
const InterpolatedGainCurve interp_gain_curve_;
|
||||
FixedDigitalLevelEstimator level_estimator_;
|
||||
ApmDataDumper* const apm_data_dumper_ = nullptr;
|
||||
|
||||
// Work array containing the sub-frame scaling factors to be interpolated.
|
||||
std::array<float, kSubFramesInFrame + 1> scaling_factors_ = {};
|
||||
std::array<float, kMaximalNumberOfSamplesPerChannel>
|
||||
per_sample_scaling_factors_ = {};
|
||||
float last_scaling_factor_ = 1.f;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
|
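A minimal usage sketch of the Limiter declared above, processing one 10 ms frame in place (the "Agc2" histogram prefix reuses a name already registered by this commit):

#include "modules/audio_processing/agc2/limiter.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

void LimiterExample(webrtc::AudioFrameView<float> frame) {
  webrtc::ApmDataDumper data_dumper(0);
  webrtc::Limiter limiter(48000, &data_dumper, "Agc2");
  // Applies the interpolated gain curve and hard-clips to the FloatS16 range.
  limiter.Process(frame);
}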
138
webrtc/modules/audio_processing/agc2/limiter_db_gain_curve.cc
Normal file
@ -0,0 +1,138 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/limiter_db_gain_curve.h"

#include <cmath>

#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/checks.h"

namespace webrtc {
namespace {

double ComputeKneeStart(double max_input_level_db,
                        double knee_smoothness_db,
                        double compression_ratio) {
  RTC_CHECK_LT((compression_ratio - 1.0) * knee_smoothness_db /
                   (2.0 * compression_ratio),
               max_input_level_db);
  return -knee_smoothness_db / 2.0 -
         max_input_level_db / (compression_ratio - 1.0);
}

std::array<double, 3> ComputeKneeRegionPolynomial(double knee_start_dbfs,
                                                  double knee_smoothness_db,
                                                  double compression_ratio) {
  const double a = (1.0 - compression_ratio) /
                   (2.0 * knee_smoothness_db * compression_ratio);
  const double b = 1.0 - 2.0 * a * knee_start_dbfs;
  const double c = a * knee_start_dbfs * knee_start_dbfs;
  return {{a, b, c}};
}

double ComputeLimiterD1(double max_input_level_db, double compression_ratio) {
  return (std::pow(10.0, -max_input_level_db / (20.0 * compression_ratio)) *
          (1.0 - compression_ratio) / compression_ratio) /
         kMaxAbsFloatS16Value;
}

constexpr double ComputeLimiterD2(double compression_ratio) {
  return (1.0 - 2.0 * compression_ratio) / compression_ratio;
}

double ComputeLimiterI2(double max_input_level_db,
                        double compression_ratio,
                        double gain_curve_limiter_i1) {
  RTC_CHECK_NE(gain_curve_limiter_i1, 0.f);
  return std::pow(10.0, -max_input_level_db / (20.0 * compression_ratio)) /
         gain_curve_limiter_i1 /
         std::pow(kMaxAbsFloatS16Value, gain_curve_limiter_i1 - 1);
}

}  // namespace

LimiterDbGainCurve::LimiterDbGainCurve()
    : max_input_level_linear_(DbfsToFloatS16(max_input_level_db_)),
      knee_start_dbfs_(ComputeKneeStart(max_input_level_db_,
                                        knee_smoothness_db_,
                                        compression_ratio_)),
      knee_start_linear_(DbfsToFloatS16(knee_start_dbfs_)),
      limiter_start_dbfs_(knee_start_dbfs_ + knee_smoothness_db_),
      limiter_start_linear_(DbfsToFloatS16(limiter_start_dbfs_)),
      knee_region_polynomial_(ComputeKneeRegionPolynomial(knee_start_dbfs_,
                                                          knee_smoothness_db_,
                                                          compression_ratio_)),
      gain_curve_limiter_d1_(
          ComputeLimiterD1(max_input_level_db_, compression_ratio_)),
      gain_curve_limiter_d2_(ComputeLimiterD2(compression_ratio_)),
      gain_curve_limiter_i1_(1.0 / compression_ratio_),
      gain_curve_limiter_i2_(ComputeLimiterI2(max_input_level_db_,
                                              compression_ratio_,
                                              gain_curve_limiter_i1_)) {
  static_assert(knee_smoothness_db_ > 0.0f, "");
  static_assert(compression_ratio_ > 1.0f, "");
  RTC_CHECK_GE(max_input_level_db_, knee_start_dbfs_ + knee_smoothness_db_);
}

constexpr double LimiterDbGainCurve::max_input_level_db_;
constexpr double LimiterDbGainCurve::knee_smoothness_db_;
constexpr double LimiterDbGainCurve::compression_ratio_;

double LimiterDbGainCurve::GetOutputLevelDbfs(double input_level_dbfs) const {
  if (input_level_dbfs < knee_start_dbfs_) {
    return input_level_dbfs;
  } else if (input_level_dbfs < limiter_start_dbfs_) {
    return GetKneeRegionOutputLevelDbfs(input_level_dbfs);
  }
  return GetCompressorRegionOutputLevelDbfs(input_level_dbfs);
}

double LimiterDbGainCurve::GetGainLinear(double input_level_linear) const {
  if (input_level_linear < knee_start_linear_) {
    return 1.0;
  }
  return DbfsToFloatS16(
             GetOutputLevelDbfs(FloatS16ToDbfs(input_level_linear))) /
         input_level_linear;
}

// Computes the first derivative of GetGainLinear() in |x|.
double LimiterDbGainCurve::GetGainFirstDerivativeLinear(double x) const {
  // Beyond-knee region only.
  RTC_CHECK_GE(x, limiter_start_linear_ - 1e-7 * kMaxAbsFloatS16Value);
  return gain_curve_limiter_d1_ *
         std::pow(x / kMaxAbsFloatS16Value, gain_curve_limiter_d2_);
}

// Computes the integral of GetGainLinear() in the range [x0, x1].
double LimiterDbGainCurve::GetGainIntegralLinear(double x0, double x1) const {
  RTC_CHECK_LE(x0, x1);                     // Valid interval.
  RTC_CHECK_GE(x0, limiter_start_linear_);  // Beyond-knee region only.
  auto limiter_integral = [this](const double& x) {
    return gain_curve_limiter_i2_ * std::pow(x, gain_curve_limiter_i1_);
  };
  return limiter_integral(x1) - limiter_integral(x0);
}

double LimiterDbGainCurve::GetKneeRegionOutputLevelDbfs(
    double input_level_dbfs) const {
  return knee_region_polynomial_[0] * input_level_dbfs * input_level_dbfs +
         knee_region_polynomial_[1] * input_level_dbfs +
         knee_region_polynomial_[2];
}

double LimiterDbGainCurve::GetCompressorRegionOutputLevelDbfs(
    double input_level_dbfs) const {
  return (input_level_dbfs - max_input_level_db_) / compression_ratio_;
}

}  // namespace webrtc
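The knee coefficients computed above follow from three continuity constraints. With k_0 the knee start (dBFS), s the knee smoothness (dB) and r the compression ratio, the quadratic used in the knee region is

\[
P(x) = a x^2 + b x + c, \qquad
a = \frac{1 - r}{2 s r}, \quad
b = 1 - 2 a k_0, \quad
c = a k_0^2 ,
\]
\[
P(k_0) = k_0, \qquad
P'(k_0) = 2 a k_0 + b = 1, \qquad
P'(k_0 + s) = 1 + 2 a s = \frac{1}{r} .
\]

The curve therefore leaves the identity region with matching value and slope, and reaches the compressor slope 1/r exactly at the knee end; ComputeKneeStart places k_0 at -s/2 - max_input_level_db/(r - 1) so that the quadratic also meets the compressor line in value at the knee end.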
76
webrtc/modules/audio_processing/agc2/limiter_db_gain_curve.h
Normal file
@ -0,0 +1,76 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_
#define MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_

#include <array>

#include "modules/audio_processing/agc2/agc2_testing_common.h"

namespace webrtc {

// A class for computing a limiter gain curve (in dB scale) given a set of
// hard-coded parameters (namely, kLimiterDbGainCurveMaxInputLevelDbFs,
// kLimiterDbGainCurveKneeSmoothnessDb, and
// kLimiterDbGainCurveCompressionRatio). The generated curve consists of four
// regions: identity (linear), knee (quadratic polynomial), compression
// (linear), saturation (linear). The aforementioned constants are used to
// shape the different regions.
class LimiterDbGainCurve {
 public:
  LimiterDbGainCurve();

  double max_input_level_db() const { return max_input_level_db_; }
  double max_input_level_linear() const { return max_input_level_linear_; }
  double knee_start_linear() const { return knee_start_linear_; }
  double limiter_start_linear() const { return limiter_start_linear_; }

  // These methods can be marked 'constexpr' in C++14.
  double GetOutputLevelDbfs(double input_level_dbfs) const;
  double GetGainLinear(double input_level_linear) const;
  double GetGainFirstDerivativeLinear(double x) const;
  double GetGainIntegralLinear(double x0, double x1) const;

 private:
  double GetKneeRegionOutputLevelDbfs(double input_level_dbfs) const;
  double GetCompressorRegionOutputLevelDbfs(double input_level_dbfs) const;

  static constexpr double max_input_level_db_ = test::kLimiterMaxInputLevelDbFs;
  static constexpr double knee_smoothness_db_ = test::kLimiterKneeSmoothnessDb;
  static constexpr double compression_ratio_ = test::kLimiterCompressionRatio;

  const double max_input_level_linear_;

  // Do not modify signal with level <= knee_start_dbfs_.
  const double knee_start_dbfs_;
  const double knee_start_linear_;

  // The upper end of the knee region, which is between knee_start_dbfs_ and
  // limiter_start_dbfs_.
  const double limiter_start_dbfs_;
  const double limiter_start_linear_;

  // Coefficients {a, b, c} of the knee region polynomial
  // ax^2 + bx + c in the dB scale.
  const std::array<double, 3> knee_region_polynomial_;

  // Parameters for the computation of the first derivative of GetGainLinear().
  const double gain_curve_limiter_d1_;
  const double gain_curve_limiter_d2_;

  // Parameters for the computation of the integral of GetGainLinear().
  const double gain_curve_limiter_i1_;
  const double gain_curve_limiter_i2_;
};

}  // namespace webrtc

#endif  // MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_
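A small sanity probe of the resulting curve, using only the accessors declared above (illustrative; the exact values depend on the constants in agc2_testing_common.h):

// Illustrative only: checks the identity region and the onset of compression.
#include "modules/audio_processing/agc2/limiter_db_gain_curve.h"
#include "rtc_base/checks.h"

void ProbeLimiterCurve() {
  webrtc::LimiterDbGainCurve curve;
  // Below the knee the curve is the identity, so the linear gain is exactly 1.
  RTC_CHECK_EQ(1.0, curve.GetGainLinear(0.5 * curve.knee_start_linear()));
  // At the start of the compression region the gain is already below 1,
  // because the knee quadratic is strictly concave.
  RTC_CHECK_LT(curve.GetGainLinear(curve.limiter_start_linear()), 1.0);
}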
114
webrtc/modules/audio_processing/agc2/noise_level_estimator.cc
Normal file
@ -0,0 +1,114 @@
/*
 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/noise_level_estimator.h"

#include <stddef.h>

#include <algorithm>
#include <cmath>
#include <numeric>

#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"

namespace webrtc {

namespace {
constexpr int kFramesPerSecond = 100;

float FrameEnergy(const AudioFrameView<const float>& audio) {
  float energy = 0.f;
  for (size_t k = 0; k < audio.num_channels(); ++k) {
    float channel_energy =
        std::accumulate(audio.channel(k).begin(), audio.channel(k).end(), 0.f,
                        [](float a, float b) -> float { return a + b * b; });
    energy = std::max(channel_energy, energy);
  }
  return energy;
}

float EnergyToDbfs(float signal_energy, size_t num_samples) {
  const float rms = std::sqrt(signal_energy / num_samples);
  return FloatS16ToDbfs(rms);
}
}  // namespace

NoiseLevelEstimator::NoiseLevelEstimator(ApmDataDumper* data_dumper)
    : signal_classifier_(data_dumper) {
  Initialize(48000);
}

NoiseLevelEstimator::~NoiseLevelEstimator() {}

void NoiseLevelEstimator::Initialize(int sample_rate_hz) {
  sample_rate_hz_ = sample_rate_hz;
  noise_energy_ = 1.f;
  first_update_ = true;
  min_noise_energy_ = sample_rate_hz * 2.f * 2.f / kFramesPerSecond;
  noise_energy_hold_counter_ = 0;
  signal_classifier_.Initialize(sample_rate_hz);
}

float NoiseLevelEstimator::Analyze(const AudioFrameView<const float>& frame) {
  const int rate =
      static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
  if (rate != sample_rate_hz_) {
    Initialize(rate);
  }
  const float frame_energy = FrameEnergy(frame);
  if (frame_energy <= 0.f) {
    RTC_DCHECK_GE(frame_energy, 0.f);
    return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
  }

  if (first_update_) {
    // Initialize the noise energy to the frame energy.
    first_update_ = false;
    return EnergyToDbfs(
        noise_energy_ = std::max(frame_energy, min_noise_energy_),
        frame.samples_per_channel());
  }

  const SignalClassifier::SignalType signal_type =
      signal_classifier_.Analyze(frame.channel(0));

  // Update the noise estimate in a minimum statistics-type manner.
  if (signal_type == SignalClassifier::SignalType::kStationary) {
    if (frame_energy > noise_energy_) {
      // Leak the estimate upwards towards the frame energy if no recent
      // downward update.
      noise_energy_hold_counter_ = std::max(noise_energy_hold_counter_ - 1, 0);

      if (noise_energy_hold_counter_ == 0) {
        noise_energy_ = std::min(noise_energy_ * 1.01f, frame_energy);
      }
    } else {
      // Update smoothly downwards with a limited maximum update magnitude.
      noise_energy_ =
          std::max(noise_energy_ * 0.9f,
                   noise_energy_ + 0.05f * (frame_energy - noise_energy_));
      noise_energy_hold_counter_ = 1000;
    }
  } else {
    // For a non-stationary signal, leak the estimate downwards in order to
    // avoid estimate locking due to incorrect signal classification.
    noise_energy_ = noise_energy_ * 0.99f;
  }

  // Ensure a minimum of the estimate.
  return EnergyToDbfs(
      noise_energy_ = std::max(noise_energy_, min_noise_energy_),
      frame.samples_per_channel());
}

}  // namespace webrtc
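The stationary-frame branch above is easier to read in isolation; the helper below restates the same constants as a standalone sketch (the real class additionally handles the non-stationary leak, the first-update path and the minimum-energy floor):

// Simplified restatement of the stationary-frame update (illustrative only).
#include <algorithm>

float UpdateNoiseEnergy(float noise_energy, float frame_energy,
                        int* hold_counter) {
  if (frame_energy > noise_energy) {
    // Upward leak: at most +1% per frame, and only after the hold period that
    // follows a downward update has expired.
    *hold_counter = std::max(*hold_counter - 1, 0);
    if (*hold_counter == 0) {
      noise_energy = std::min(noise_energy * 1.01f, frame_energy);
    }
  } else {
    // Downward update: follow the frame energy, but never drop by more than
    // 10% in one step; then block upward updates for 1000 frames (10 s).
    noise_energy =
        std::max(noise_energy * 0.9f,
                 noise_energy + 0.05f * (frame_energy - noise_energy));
    *hold_counter = 1000;
  }
  return noise_energy;
}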
43
webrtc/modules/audio_processing/agc2/noise_level_estimator.h
Normal file
@ -0,0 +1,43 @@
/*
 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_

#include "modules/audio_processing/agc2/signal_classifier.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/constructor_magic.h"

namespace webrtc {
class ApmDataDumper;

class NoiseLevelEstimator {
 public:
  NoiseLevelEstimator(ApmDataDumper* data_dumper);
  ~NoiseLevelEstimator();
  // Returns the estimated noise level in dBFS.
  float Analyze(const AudioFrameView<const float>& frame);

 private:
  void Initialize(int sample_rate_hz);

  int sample_rate_hz_;
  float min_noise_energy_;
  bool first_update_;
  float noise_energy_;
  int noise_energy_hold_counter_;
  SignalClassifier signal_classifier_;

  RTC_DISALLOW_COPY_AND_ASSIGN(NoiseLevelEstimator);
};

}  // namespace webrtc

#endif  // MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
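A minimal sketch of how the estimator is driven, assuming 10 ms frames and the pointer-array AudioFrameView constructor; the estimator infers the sample rate from the frame length, so a 480-sample frame implies 48 kHz:

// Illustrative only: feeds one silent 10 ms mono frame to the estimator.
#include <vector>

#include "modules/audio_processing/agc2/noise_level_estimator.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

float EstimateNoiseDbfs(webrtc::ApmDataDumper* data_dumper) {
  webrtc::NoiseLevelEstimator estimator(data_dumper);
  std::vector<float> samples(480, 0.f);  // 10 ms at 48 kHz, mono.
  const float* channels[] = {samples.data()};
  webrtc::AudioFrameView<const float> frame(channels, /*num_channels=*/1,
                                            /*channel_size=*/480);
  return estimator.Analyze(frame);  // Estimated noise level in dBFS.
}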
@ -0,0 +1,70 @@
/*
 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/noise_spectrum_estimator.h"

#include <string.h>

#include <algorithm>

#include "api/array_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/arraysize.h"
#include "rtc_base/checks.h"

namespace webrtc {
namespace {
constexpr float kMinNoisePower = 100.f;
}  // namespace

NoiseSpectrumEstimator::NoiseSpectrumEstimator(ApmDataDumper* data_dumper)
    : data_dumper_(data_dumper) {
  Initialize();
}

void NoiseSpectrumEstimator::Initialize() {
  std::fill(noise_spectrum_, noise_spectrum_ + arraysize(noise_spectrum_),
            kMinNoisePower);
}

void NoiseSpectrumEstimator::Update(rtc::ArrayView<const float> spectrum,
                                    bool first_update) {
  RTC_DCHECK_EQ(65, spectrum.size());

  if (first_update) {
    // Initialize the noise spectral estimate with the signal spectrum.
    std::copy(spectrum.data(), spectrum.data() + spectrum.size(),
              noise_spectrum_);
  } else {
    // Smoothly update the noise spectral estimate towards the signal spectrum
    // such that the magnitude of each update is limited.
    for (size_t k = 0; k < spectrum.size(); ++k) {
      if (noise_spectrum_[k] < spectrum[k]) {
        noise_spectrum_[k] = std::min(
            1.01f * noise_spectrum_[k],
            noise_spectrum_[k] + 0.05f * (spectrum[k] - noise_spectrum_[k]));
      } else {
        noise_spectrum_[k] = std::max(
            0.99f * noise_spectrum_[k],
            noise_spectrum_[k] + 0.05f * (spectrum[k] - noise_spectrum_[k]));
      }
    }
  }

  // Ensure that the noise spectral estimate does not become too low.
  for (auto& v : noise_spectrum_) {
    v = std::max(v, kMinNoisePower);
  }

  data_dumper_->DumpRaw("lc_noise_spectrum", 65, noise_spectrum_);
  data_dumper_->DumpRaw("lc_signal_spectrum", spectrum);
}

}  // namespace webrtc
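The per-bin branch above can equivalently be written as a single clamp, which makes the design easier to see: move 5% of the way towards the observed spectrum, but never change a bin by more than 1% per frame. This is an illustrative restatement that assumes a strictly positive noise estimate, which the kMinNoisePower floor guarantees:

// Equivalent per-bin update written as one clamp (illustrative only).
#include <algorithm>

float UpdateNoiseBin(float noise, float signal) {
  const float target = noise + 0.05f * (signal - noise);
  return std::min(std::max(target, 0.99f * noise), 1.01f * noise);
}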
@ -0,0 +1,42 @@
/*
 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_

#include "api/array_view.h"

namespace webrtc {

class ApmDataDumper;

class NoiseSpectrumEstimator {
 public:
  explicit NoiseSpectrumEstimator(ApmDataDumper* data_dumper);

  NoiseSpectrumEstimator() = delete;
  NoiseSpectrumEstimator(const NoiseSpectrumEstimator&) = delete;
  NoiseSpectrumEstimator& operator=(const NoiseSpectrumEstimator&) = delete;

  void Initialize();
  void Update(rtc::ArrayView<const float> spectrum, bool first_update);

  rtc::ArrayView<const float> GetNoiseSpectrum() const {
    return rtc::ArrayView<const float>(noise_spectrum_);
  }

 private:
  ApmDataDumper* data_dumper_;
  float noise_spectrum_[65];
};

}  // namespace webrtc

#endif  // MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_
233
webrtc/modules/audio_processing/agc2/rnn_vad/BUILD.gn
Normal file
@ -0,0 +1,233 @@
|
||||
# Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
import("../../../../webrtc.gni")
|
||||
|
||||
rtc_library("rnn_vad") {
|
||||
visibility = [ "../*" ]
|
||||
sources = [
|
||||
"features_extraction.cc",
|
||||
"features_extraction.h",
|
||||
"rnn.cc",
|
||||
"rnn.h",
|
||||
]
|
||||
|
||||
if (rtc_build_with_neon && current_cpu != "arm64") {
|
||||
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
|
||||
cflags = [ "-mfpu=neon" ]
|
||||
}
|
||||
|
||||
deps = [
|
||||
":rnn_vad_common",
|
||||
":rnn_vad_lp_residual",
|
||||
":rnn_vad_pitch",
|
||||
":rnn_vad_sequence_buffer",
|
||||
":rnn_vad_spectral_features",
|
||||
"..:biquad_filter",
|
||||
"../../../../api:array_view",
|
||||
"../../../../api:function_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:logging",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"//third_party/rnnoise:rnn_vad",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_library("rnn_vad_auto_correlation") {
|
||||
sources = [
|
||||
"auto_correlation.cc",
|
||||
"auto_correlation.h",
|
||||
]
|
||||
deps = [
|
||||
":rnn_vad_common",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../utility:pffft_wrapper",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_library("rnn_vad_common") {
|
||||
# TODO(alessiob): Make this target visibility private.
|
||||
visibility = [
|
||||
":*",
|
||||
"..:rnn_vad_with_level",
|
||||
]
|
||||
sources = [
|
||||
"common.cc",
|
||||
"common.h",
|
||||
]
|
||||
deps = [
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../system_wrappers",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_library("rnn_vad_lp_residual") {
|
||||
sources = [
|
||||
"lp_residual.cc",
|
||||
"lp_residual.h",
|
||||
]
|
||||
deps = [
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_library("rnn_vad_pitch") {
|
||||
sources = [
|
||||
"pitch_info.h",
|
||||
"pitch_search.cc",
|
||||
"pitch_search.h",
|
||||
"pitch_search_internal.cc",
|
||||
"pitch_search_internal.h",
|
||||
]
|
||||
deps = [
|
||||
":rnn_vad_auto_correlation",
|
||||
":rnn_vad_common",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("rnn_vad_ring_buffer") {
|
||||
sources = [ "ring_buffer.h" ]
|
||||
deps = [
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("rnn_vad_sequence_buffer") {
|
||||
sources = [ "sequence_buffer.h" ]
|
||||
deps = [
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_library("rnn_vad_spectral_features") {
|
||||
sources = [
|
||||
"spectral_features.cc",
|
||||
"spectral_features.h",
|
||||
"spectral_features_internal.cc",
|
||||
"spectral_features_internal.h",
|
||||
]
|
||||
deps = [
|
||||
":rnn_vad_common",
|
||||
":rnn_vad_ring_buffer",
|
||||
":rnn_vad_symmetric_matrix_buffer",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../utility:pffft_wrapper",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("rnn_vad_symmetric_matrix_buffer") {
|
||||
sources = [ "symmetric_matrix_buffer.h" ]
|
||||
deps = [
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
]
|
||||
}
|
||||
|
||||
if (rtc_include_tests) {
|
||||
rtc_library("test_utils") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"test_utils.cc",
|
||||
"test_utils.h",
|
||||
]
|
||||
deps = [
|
||||
":rnn_vad",
|
||||
":rnn_vad_common",
|
||||
"../../../../api:array_view",
|
||||
"../../../../api:scoped_refptr",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../system_wrappers",
|
||||
"../../../../test:fileutils",
|
||||
"../../../../test:test_support",
|
||||
]
|
||||
}
|
||||
|
||||
unittest_resources = [
|
||||
"../../../../resources/audio_processing/agc2/rnn_vad/band_energies.dat",
|
||||
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat",
|
||||
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat",
|
||||
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_search_int.dat",
|
||||
"../../../../resources/audio_processing/agc2/rnn_vad/samples.pcm",
|
||||
"../../../../resources/audio_processing/agc2/rnn_vad/vad_prob.dat",
|
||||
]
|
||||
|
||||
if (is_ios) {
|
||||
bundle_data("unittests_bundle_data") {
|
||||
testonly = true
|
||||
sources = unittest_resources
|
||||
outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ]
|
||||
}
|
||||
}
|
||||
|
||||
rtc_library("unittests") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"auto_correlation_unittest.cc",
|
||||
"features_extraction_unittest.cc",
|
||||
"lp_residual_unittest.cc",
|
||||
"pitch_search_internal_unittest.cc",
|
||||
"pitch_search_unittest.cc",
|
||||
"ring_buffer_unittest.cc",
|
||||
"rnn_unittest.cc",
|
||||
"rnn_vad_unittest.cc",
|
||||
"sequence_buffer_unittest.cc",
|
||||
"spectral_features_internal_unittest.cc",
|
||||
"spectral_features_unittest.cc",
|
||||
"symmetric_matrix_buffer_unittest.cc",
|
||||
]
|
||||
deps = [
|
||||
":rnn_vad",
|
||||
":rnn_vad_auto_correlation",
|
||||
":rnn_vad_common",
|
||||
":rnn_vad_lp_residual",
|
||||
":rnn_vad_pitch",
|
||||
":rnn_vad_ring_buffer",
|
||||
":rnn_vad_sequence_buffer",
|
||||
":rnn_vad_spectral_features",
|
||||
":rnn_vad_symmetric_matrix_buffer",
|
||||
":test_utils",
|
||||
"../..:audioproc_test_utils",
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio/",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:logging",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../test:test_support",
|
||||
"../../utility:pffft_wrapper",
|
||||
"//third_party/rnnoise:rnn_vad",
|
||||
]
|
||||
absl_deps = [ "//third_party/abseil-cpp/absl/memory" ]
|
||||
data = unittest_resources
|
||||
if (is_ios) {
|
||||
deps += [ ":unittests_bundle_data" ]
|
||||
}
|
||||
}
|
||||
|
||||
rtc_executable("rnn_vad_tool") {
|
||||
testonly = true
|
||||
sources = [ "rnn_vad_tool.cc" ]
|
||||
deps = [
|
||||
":rnn_vad",
|
||||
":rnn_vad_common",
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio",
|
||||
"../../../../rtc_base:rtc_base_approved",
|
||||
"../../../../test:test_support",
|
||||
"//third_party/abseil-cpp/absl/flags:flag",
|
||||
"//third_party/abseil-cpp/absl/flags:parse",
|
||||
]
|
||||
}
|
||||
}
|
@ -0,0 +1,92 @@
|
||||
/*
|
||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
|
||||
static_assert(1 << kAutoCorrelationFftOrder >
|
||||
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
|
||||
"");
|
||||
|
||||
} // namespace
|
||||
|
||||
AutoCorrelationCalculator::AutoCorrelationCalculator()
|
||||
: fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
|
||||
tmp_(fft_.CreateBuffer()),
|
||||
X_(fft_.CreateBuffer()),
|
||||
H_(fft_.CreateBuffer()) {}
|
||||
|
||||
AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
|
||||
|
||||
// The auto-correlations coefficients are computed as follows:
|
||||
// |.........|...........| <- pitch buffer
|
||||
// [ x (fixed) ]
|
||||
// [ y_0 ]
|
||||
// [ y_{m-1} ]
|
||||
// x and y are sub-array of equal length; x is never moved, whereas y slides.
|
||||
// The cross-correlation between y_0 and x corresponds to the auto-correlation
|
||||
// for the maximum pitch period. Hence, the first value in |auto_corr| has an
|
||||
// inverted lag equal to 0 that corresponds to a lag equal to the maximum
|
||||
// pitch period.
|
||||
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
|
||||
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
|
||||
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
|
||||
constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder;
|
||||
constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
|
||||
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
|
||||
"Mismatch between pitch buffer size, frame size and maximum "
|
||||
"pitch period.");
|
||||
static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength,
|
||||
"The FFT length is not sufficiently big to avoid cyclic "
|
||||
"convolution errors.");
|
||||
auto tmp = tmp_->GetView();
|
||||
|
||||
// Compute the FFT for the reversed reference frame - i.e.,
|
||||
// pitch_buf[-kConvolutionLength:].
|
||||
std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(),
|
||||
tmp.begin());
|
||||
std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f);
|
||||
fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false);
|
||||
|
||||
// Compute the FFT for the sliding frames chunk. The sliding frames are
|
||||
// defined as pitch_buf[i:i+kConvolutionLength] where i in
|
||||
// [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is
|
||||
// defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength].
|
||||
std::copy(pitch_buf.begin(),
|
||||
pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz,
|
||||
tmp.begin());
|
||||
std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(),
|
||||
0.f);
|
||||
fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
|
||||
|
||||
// Convolve in the frequency domain.
|
||||
constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize);
|
||||
std::fill(tmp.begin(), tmp.end(), 0.f);
|
||||
fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor);
|
||||
fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false);
|
||||
|
||||
// Extract the auto-correlation coefficients.
|
||||
std::copy(tmp.begin() + kConvolutionLength - 1,
|
||||
tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1,
|
||||
auto_corr.begin());
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
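For reference, the coefficients produced by the FFT-based path above equal this direct time-domain computation (an illustrative sketch; it mirrors ComputeAutoCorrelationCoeff in pitch_search_internal.cc, is much slower, but is useful as a mental model or for testing):

// Direct (time-domain) equivalent of the FFT-based auto-correlation above.
#include <numeric>

#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"

namespace webrtc {
namespace rnn_vad {

void ComputeAutoCorrelationDirect(
    rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
    rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
  for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
    // Inner product between the fixed reference frame (the most recent
    // kBufSize12kHz - kMaxPitch12kHz samples) and the equally long frame
    // starting at |inv_lag|.
    auto_corr[inv_lag] =
        std::inner_product(pitch_buf.begin() + kMaxPitch12kHz, pitch_buf.end(),
                           pitch_buf.begin() + inv_lag, 0.f);
  }
}

}  // namespace rnn_vad
}  // namespace webrtc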
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/utility/pffft_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Class to compute the auto correlation on the pitch buffer for a target pitch
|
||||
// interval.
|
||||
class AutoCorrelationCalculator {
|
||||
public:
|
||||
AutoCorrelationCalculator();
|
||||
AutoCorrelationCalculator(const AutoCorrelationCalculator&) = delete;
|
||||
AutoCorrelationCalculator& operator=(const AutoCorrelationCalculator&) =
|
||||
delete;
|
||||
~AutoCorrelationCalculator();
|
||||
|
||||
// Computes the auto-correlation coefficients for a target pitch interval.
|
||||
// |auto_corr| indexes are inverted lags.
|
||||
void ComputeOnPitchBuffer(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
|
||||
|
||||
private:
|
||||
Pffft fft_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> tmp_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> X_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> H_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
|
34
webrtc/modules/audio_processing/agc2/rnn_vad/common.cc
Normal file
@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
Optimization DetectOptimization() {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
if (GetCPUInfo(kSSE2) != 0) {
|
||||
return Optimization::kSse2;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
return Optimization::kNeon;
|
||||
#endif
|
||||
|
||||
return Optimization::kNone;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
76
webrtc/modules/audio_processing/agc2/rnn_vad/common.h
Normal file
@ -0,0 +1,76 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
constexpr double kPi = 3.14159265358979323846;
|
||||
|
||||
constexpr size_t kSampleRate24kHz = 24000;
|
||||
constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100;
|
||||
constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
|
||||
|
||||
// Pitch buffer.
|
||||
constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
|
||||
constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
|
||||
constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
|
||||
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
|
||||
|
||||
// 24 kHz analysis.
|
||||
// Define a higher minimum pitch period for the initial search. This is used to
|
||||
// avoid searching for very short periods, for which a refinement step is
|
||||
// responsible.
|
||||
constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
|
||||
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
|
||||
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
|
||||
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
|
||||
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
|
||||
|
||||
// 12 kHz analysis.
|
||||
constexpr size_t kSampleRate12kHz = 12000;
|
||||
constexpr size_t kFrameSize10ms12kHz = kSampleRate12kHz / 100;
|
||||
constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
|
||||
constexpr size_t kBufSize12kHz = kBufSize24kHz / 2;
|
||||
constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
|
||||
constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2;
|
||||
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
|
||||
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
|
||||
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
|
||||
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
|
||||
|
||||
// 48 kHz constants.
|
||||
constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
|
||||
constexpr size_t kMaxPitch48kHz = kMaxPitch24kHz * 2;
|
||||
|
||||
// Spectral features.
|
||||
constexpr size_t kNumBands = 22;
|
||||
constexpr size_t kNumLowerBands = 6;
|
||||
static_assert((0 < kNumLowerBands) && (kNumLowerBands < kNumBands), "");
|
||||
constexpr size_t kCepstralCoeffsHistorySize = 8;
|
||||
static_assert(kCepstralCoeffsHistorySize > 2,
|
||||
"The history size must at least be 3 to compute first and second "
|
||||
"derivatives.");
|
||||
|
||||
constexpr size_t kFeatureVectorSize = 42;
|
||||
|
||||
enum class Optimization { kNone, kSse2, kNeon };
|
||||
|
||||
// Detects what kind of optimizations to use for the code.
|
||||
Optimization DetectOptimization();
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
|
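Spelled out numerically, the derived constants above work out as follows (a compile-time restatement, illustrative only):

#include "modules/audio_processing/agc2/rnn_vad/common.h"

namespace webrtc {
namespace rnn_vad {

// Derived values of the constants defined in common.h.
static_assert(kFrameSize10ms24kHz == 240, "");
static_assert(kMinPitch24kHz == 30, "");    // 24000 / 800.
static_assert(kMaxPitch24kHz == 384, "");   // 24000 / 62.5.
static_assert(kBufSize24kHz == 864, "");    // 384 + 480.
static_assert(kInitialMinPitch24kHz == 90, "");
static_assert(kNumInvertedLags24kHz == 294, "");  // 384 - 90.
static_assert(kBufSize12kHz == 432, "");
static_assert(kNumInvertedLags12kHz == 147, "");  // 192 - 45.

}  // namespace rnn_vad
}  // namespace webrtc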
@ -0,0 +1,90 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Generated via "B, A = scipy.signal.butter(2, 30/12000, btype='highpass')"
|
||||
const BiQuadFilter::BiQuadCoefficients kHpfConfig24k = {
|
||||
{0.99446179f, -1.98892358f, 0.99446179f},
|
||||
{-1.98889291f, 0.98895425f}};
|
||||
|
||||
} // namespace
|
||||
|
||||
FeaturesExtractor::FeaturesExtractor()
|
||||
: use_high_pass_filter_(false),
|
||||
pitch_buf_24kHz_(),
|
||||
pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()),
|
||||
lp_residual_(kBufSize24kHz),
|
||||
lp_residual_view_(lp_residual_.data(), kBufSize24kHz),
|
||||
pitch_estimator_(),
|
||||
reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) {
|
||||
RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size());
|
||||
hpf_.Initialize(kHpfConfig24k);
|
||||
Reset();
|
||||
}
|
||||
|
||||
FeaturesExtractor::~FeaturesExtractor() = default;
|
||||
|
||||
void FeaturesExtractor::Reset() {
|
||||
pitch_buf_24kHz_.Reset();
|
||||
spectral_features_extractor_.Reset();
|
||||
if (use_high_pass_filter_)
|
||||
hpf_.Reset();
|
||||
}
|
||||
|
||||
bool FeaturesExtractor::CheckSilenceComputeFeatures(
|
||||
rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
|
||||
rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
|
||||
// Pre-processing.
|
||||
if (use_high_pass_filter_) {
|
||||
std::array<float, kFrameSize10ms24kHz> samples_filtered;
|
||||
hpf_.Process(samples, samples_filtered);
|
||||
// Feed buffer with the pre-processed version of |samples|.
|
||||
pitch_buf_24kHz_.Push(samples_filtered);
|
||||
} else {
|
||||
// Feed buffer with |samples|.
|
||||
pitch_buf_24kHz_.Push(samples);
|
||||
}
|
||||
// Extract the LP residual.
|
||||
float lpc_coeffs[kNumLpcCoefficients];
|
||||
ComputeAndPostProcessLpcCoefficients(pitch_buf_24kHz_view_, lpc_coeffs);
|
||||
ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_);
|
||||
// Estimate pitch on the LP-residual and write the normalized pitch period
|
||||
// into the output vector (normalization based on training data stats).
|
||||
pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
|
||||
feature_vector[kFeatureVectorSize - 2] =
|
||||
0.01f * (static_cast<int>(pitch_info_48kHz_.period) - 300);
|
||||
// Extract lagged frames (according to the estimated pitch period).
|
||||
RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz);
|
||||
auto lagged_frame = pitch_buf_24kHz_view_.subview(
|
||||
kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz);
|
||||
// Analyze reference and lagged frames checking if silence has been detected
|
||||
// and write the feature vector.
|
||||
return spectral_features_extractor_.CheckSilenceComputeFeatures(
|
||||
reference_frame_view_, {lagged_frame.data(), kFrameSize20ms24kHz},
|
||||
{feature_vector.data() + kNumLowerBands, kNumBands - kNumLowerBands},
|
||||
{feature_vector.data(), kNumLowerBands},
|
||||
{feature_vector.data() + kNumBands, kNumLowerBands},
|
||||
{feature_vector.data() + kNumBands + kNumLowerBands, kNumLowerBands},
|
||||
{feature_vector.data() + kNumBands + 2 * kNumLowerBands, kNumLowerBands},
|
||||
&feature_vector[kFeatureVectorSize - 1]);
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
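The array slicing in CheckSilenceComputeFeatures implies the following layout of the 42-element feature vector; the offsets below are an inferred summary for readability, not upstream identifiers:

// Inferred layout of the 42-dimensional feature vector written above.
#include <cstddef>

namespace webrtc {
namespace rnn_vad {

constexpr size_t kFeatureAverageOffset = 0;            // 6 lower-band coefficients.
constexpr size_t kFeatureHigherBandsOffset = 6;        // 16 higher-band coefficients.
constexpr size_t kFeatureFirstDerivativeOffset = 22;   // 6 values, lower bands.
constexpr size_t kFeatureSecondDerivativeOffset = 28;  // 6 values, lower bands.
constexpr size_t kFeatureCrossCorrOffset = 34;         // 6 reference/lagged correlations.
constexpr size_t kFeaturePitchPeriodIndex = 40;  // 0.01 * (48 kHz period - 300).
constexpr size_t kFeatureVariabilityIndex = 41;  // Spectral variability measure.

}  // namespace rnn_vad
}  // namespace webrtc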
@ -0,0 +1,62 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/biquad_filter.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Feature extractor to feed the VAD RNN.
|
||||
class FeaturesExtractor {
|
||||
public:
|
||||
FeaturesExtractor();
|
||||
FeaturesExtractor(const FeaturesExtractor&) = delete;
|
||||
FeaturesExtractor& operator=(const FeaturesExtractor&) = delete;
|
||||
~FeaturesExtractor();
|
||||
void Reset();
|
||||
// Analyzes the samples, computes the feature vector and returns true if
|
||||
// silence is detected (false if not). When silence is detected,
|
||||
// |feature_vector| is partially written and therefore must not be used to
|
||||
// feed the VAD RNN.
|
||||
bool CheckSilenceComputeFeatures(
|
||||
rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
|
||||
rtc::ArrayView<float, kFeatureVectorSize> feature_vector);
|
||||
|
||||
private:
|
||||
const bool use_high_pass_filter_;
|
||||
// TODO(bugs.webrtc.org/7494): Remove HPF depending on how AGC2 is used in APM
|
||||
// and on whether an HPF is already used as pre-processing step in APM.
|
||||
BiQuadFilter hpf_;
|
||||
SequenceBuffer<float, kBufSize24kHz, kFrameSize10ms24kHz, kFrameSize20ms24kHz>
|
||||
pitch_buf_24kHz_;
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf_24kHz_view_;
|
||||
std::vector<float> lp_residual_;
|
||||
rtc::ArrayView<float, kBufSize24kHz> lp_residual_view_;
|
||||
PitchEstimator pitch_estimator_;
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
|
||||
SpectralFeaturesExtractor spectral_features_extractor_;
|
||||
PitchInfo pitch_info_48kHz_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
|
138
webrtc/modules/audio_processing/agc2/rnn_vad/lp_residual.cc
Normal file
@ -0,0 +1,138 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Computes cross-correlation coefficients between |x| and |y| and writes them
|
||||
// in |x_corr|. The lag values are in {0, ..., max_lag - 1}, where max_lag
|
||||
// equals the size of |x_corr|.
|
||||
// The |x| and |y| sub-arrays used to compute a cross-correlation coefficients
|
||||
// for a lag l have both size "size of |x| - l" - i.e., the longest sub-array is
|
||||
// used. |x| and |y| must have the same size.
|
||||
void ComputeCrossCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> x_corr) {
|
||||
constexpr size_t max_lag = x_corr.size();
|
||||
RTC_DCHECK_EQ(x.size(), y.size());
|
||||
RTC_DCHECK_LT(max_lag, x.size());
|
||||
for (size_t lag = 0; lag < max_lag; ++lag) {
|
||||
x_corr[lag] =
|
||||
std::inner_product(x.begin(), x.end() - lag, y.begin() + lag, 0.f);
|
||||
}
|
||||
}
|
||||
|
||||
// Applies denoising to the auto-correlation coefficients.
|
||||
void DenoiseAutoCorrelation(
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
|
||||
// Assume -40 dB white noise floor.
|
||||
auto_corr[0] *= 1.0001f;
|
||||
for (size_t i = 1; i < kNumLpcCoefficients; ++i) {
|
||||
auto_corr[i] -= auto_corr[i] * (0.008f * i) * (0.008f * i);
|
||||
}
|
||||
}
|
||||
|
||||
// Computes the initial inverse filter coefficients given the auto-correlation
|
||||
// coefficients of an input frame.
|
||||
void ComputeInitialInverseFilterCoefficients(
|
||||
rtc::ArrayView<const float, kNumLpcCoefficients> auto_corr,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients - 1> lpc_coeffs) {
|
||||
float error = auto_corr[0];
|
||||
for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
|
||||
float reflection_coeff = 0.f;
|
||||
for (size_t j = 0; j < i; ++j) {
|
||||
reflection_coeff += lpc_coeffs[j] * auto_corr[i - j];
|
||||
}
|
||||
reflection_coeff += auto_corr[i + 1];
|
||||
|
||||
// Avoid division by numbers close to zero.
|
||||
constexpr float kMinErrorMagnitude = 1e-6f;
|
||||
if (std::fabs(error) < kMinErrorMagnitude) {
|
||||
error = std::copysign(kMinErrorMagnitude, error);
|
||||
}
|
||||
|
||||
reflection_coeff /= -error;
|
||||
// Update LPC coefficients and total error.
|
||||
lpc_coeffs[i] = reflection_coeff;
|
||||
for (size_t j = 0; j < ((i + 1) >> 1); ++j) {
|
||||
const float tmp1 = lpc_coeffs[j];
|
||||
const float tmp2 = lpc_coeffs[i - 1 - j];
|
||||
lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2;
|
||||
lpc_coeffs[i - 1 - j] = tmp2 + reflection_coeff * tmp1;
|
||||
}
|
||||
error -= reflection_coeff * reflection_coeff * error;
|
||||
if (error < 0.001f * auto_corr[0]) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ComputeAndPostProcessLpcCoefficients(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs) {
|
||||
std::array<float, kNumLpcCoefficients> auto_corr;
|
||||
ComputeCrossCorrelation(x, x, {auto_corr.data(), auto_corr.size()});
|
||||
if (auto_corr[0] == 0.f) { // Empty frame.
|
||||
std::fill(lpc_coeffs.begin(), lpc_coeffs.end(), 0);
|
||||
return;
|
||||
}
|
||||
DenoiseAutoCorrelation({auto_corr.data(), auto_corr.size()});
|
||||
std::array<float, kNumLpcCoefficients - 1> lpc_coeffs_pre{};
|
||||
ComputeInitialInverseFilterCoefficients(auto_corr, lpc_coeffs_pre);
|
||||
// LPC coefficients post-processing.
|
||||
// TODO(bugs.webrtc.org/9076): Consider removing these steps.
|
||||
float c1 = 1.f;
|
||||
for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
|
||||
c1 *= 0.9f;
|
||||
lpc_coeffs_pre[i] *= c1;
|
||||
}
|
||||
const float c2 = 0.8f;
|
||||
lpc_coeffs[0] = lpc_coeffs_pre[0] + c2;
|
||||
lpc_coeffs[1] = lpc_coeffs_pre[1] + c2 * lpc_coeffs_pre[0];
|
||||
lpc_coeffs[2] = lpc_coeffs_pre[2] + c2 * lpc_coeffs_pre[1];
|
||||
lpc_coeffs[3] = lpc_coeffs_pre[3] + c2 * lpc_coeffs_pre[2];
|
||||
lpc_coeffs[4] = c2 * lpc_coeffs_pre[3];
|
||||
}
|
||||
|
||||
void ComputeLpResidual(
|
||||
rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float> y) {
|
||||
RTC_DCHECK_LT(kNumLpcCoefficients, x.size());
|
||||
RTC_DCHECK_EQ(x.size(), y.size());
|
||||
std::array<float, kNumLpcCoefficients> input_chunk;
|
||||
input_chunk.fill(0.f);
|
||||
for (size_t i = 0; i < y.size(); ++i) {
|
||||
const float sum = std::inner_product(input_chunk.begin(), input_chunk.end(),
|
||||
lpc_coeffs.begin(), x[i]);
|
||||
// Circular shift and add a new sample.
|
||||
for (size_t j = kNumLpcCoefficients - 1; j > 0; --j)
|
||||
input_chunk[j] = input_chunk[j - 1];
|
||||
input_chunk[0] = x[i];
|
||||
// Copy result.
|
||||
y[i] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
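The two functions above are used together as a whitening step before the pitch search; a minimal sketch of that pipeline, mirroring what FeaturesExtractor does (buffer sizes and types follow common.h), is:

// Illustrative whitening pipeline built from the two functions above.
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"

namespace webrtc {
namespace rnn_vad {

void WhitenPitchBuffer(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
                       rtc::ArrayView<float, kBufSize24kHz> lp_residual) {
  // 1) Fit a 5-tap inverse LPC filter to the buffer (with the built-in
  //    denoising and post-processing).
  float lpc_coeffs[kNumLpcCoefficients];
  ComputeAndPostProcessLpcCoefficients(pitch_buf, lpc_coeffs);
  // 2) Inverse-filter the buffer to obtain the LP residual used by the
  //    pitch estimator.
  ComputeLpResidual(lpc_coeffs, pitch_buf, lp_residual);
}

}  // namespace rnn_vad
}  // namespace webrtc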
41
webrtc/modules/audio_processing/agc2/rnn_vad/lp_residual.h
Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "api/array_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// LPC inverse filter length.
|
||||
constexpr size_t kNumLpcCoefficients = 5;
|
||||
|
||||
// Given a frame |x|, computes a post-processed version of LPC coefficients
|
||||
// tailored for pitch estimation.
|
||||
void ComputeAndPostProcessLpcCoefficients(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs);
|
||||
|
||||
// Computes the LP residual for the input frame |x| and the LPC coefficients
|
||||
// |lpc_coeffs|. |y| and |x| can point to the same array for in-place
|
||||
// computation.
|
||||
void ComputeLpResidual(
|
||||
rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float> y);
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
|
29
webrtc/modules/audio_processing/agc2/rnn_vad/pitch_info.h
Normal file
@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Stores pitch period and gain information. The pitch gain measures the
|
||||
// strength of the pitch (the higher, the stronger).
|
||||
struct PitchInfo {
|
||||
PitchInfo() : period(0), gain(0.f) {}
|
||||
PitchInfo(int p, float g) : period(p), gain(g) {}
|
||||
int period;
|
||||
float gain;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
|
56
webrtc/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
Normal file
@ -0,0 +1,56 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
PitchEstimator::PitchEstimator()
|
||||
: pitch_buf_decimated_(kBufSize12kHz),
|
||||
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
|
||||
auto_corr_(kNumInvertedLags12kHz),
|
||||
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
|
||||
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
|
||||
RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size());
|
||||
}
|
||||
|
||||
PitchEstimator::~PitchEstimator() = default;
|
||||
|
||||
PitchInfo PitchEstimator::Estimate(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
|
||||
// Perform the initial pitch search at 12 kHz.
|
||||
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
|
||||
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
|
||||
auto_corr_view_);
|
||||
std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
|
||||
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
|
||||
// Refine the pitch period estimation.
|
||||
// The refinement is done using the pitch buffer that contains 24 kHz samples.
|
||||
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
|
||||
// to 24 kHz.
|
||||
pitch_candidates_inv_lags[0] *= 2;
|
||||
pitch_candidates_inv_lags[1] *= 2;
|
||||
size_t pitch_inv_lag_48kHz =
|
||||
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags);
|
||||
// Look for stronger harmonics to find the final pitch period and its gain.
|
||||
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
|
||||
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
|
||||
return last_pitch_48kHz_;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
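The inverted-lag bookkeeping used above reduces to a single convention, restated here for reference (illustrative helper, not an upstream function): an inverted lag i corresponds to the pitch period kMaxPitch48kHz - i, so inverted lag 0 maps to the maximum pitch period.

#include "modules/audio_processing/agc2/rnn_vad/common.h"

namespace webrtc {
namespace rnn_vad {

// Maps a 48 kHz inverted lag to the corresponding pitch period in samples.
constexpr size_t InvertedLagToPeriod48kHz(size_t inverted_lag) {
  return kMaxPitch48kHz - inverted_lag;
}

}  // namespace rnn_vad
}  // namespace webrtc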
49
webrtc/modules/audio_processing/agc2/rnn_vad/pitch_search.h
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Pitch estimator.
|
||||
class PitchEstimator {
|
||||
public:
|
||||
PitchEstimator();
|
||||
PitchEstimator(const PitchEstimator&) = delete;
|
||||
PitchEstimator& operator=(const PitchEstimator&) = delete;
|
||||
~PitchEstimator();
|
||||
// Estimates the pitch period and gain. Returns the pitch estimation data for
|
||||
// 48 kHz.
|
||||
PitchInfo Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf);
|
||||
|
||||
private:
|
||||
PitchInfo last_pitch_48kHz_;
|
||||
AutoCorrelationCalculator auto_corr_calculator_;
|
||||
std::vector<float> pitch_buf_decimated_;
|
||||
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
|
||||
std::vector<float> auto_corr_;
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr_view_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
|
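Editor's note: a minimal usage sketch for the estimator declared above (not part of the upstream tree). It assumes only the types and constants from pitch_search.h, common.h and pitch_info.h; the buffer contents are placeholder zeros, whereas in the real pipeline they come from the feature extraction code.

#include <vector>

#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"

void PitchEstimatorUsageSketch() {
  using namespace webrtc::rnn_vad;
  // Normally filled with 24 kHz samples by the feature extraction pipeline.
  std::vector<float> pitch_buf(kBufSize24kHz, 0.f);
  PitchEstimator estimator;
  PitchInfo pitch = estimator.Estimate({pitch_buf.data(), kBufSize24kHz});
  // |pitch.period| is expressed at 48 kHz and |pitch.gain| is the pitch
  // strength; the estimator also keeps them internally as the previous
  // estimate for the next call.
  (void)pitch;
}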
@ -0,0 +1,403 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <numeric>
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Converts a lag to an inverted lag (only for 24kHz).
|
||||
size_t GetInvertedLag(size_t lag) {
|
||||
RTC_DCHECK_LE(lag, kMaxPitch24kHz);
|
||||
return kMaxPitch24kHz - lag;
|
||||
}
|
||||
|
||||
float ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf,
|
||||
size_t inv_lag,
|
||||
size_t max_pitch_period) {
|
||||
RTC_DCHECK_LT(inv_lag, pitch_buf.size());
|
||||
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
|
||||
RTC_DCHECK_LE(inv_lag, max_pitch_period);
|
||||
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
|
||||
return std::inner_product(pitch_buf.begin() + max_pitch_period,
|
||||
pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f);
|
||||
}
|
||||
|
||||
// Computes a pseudo-interpolation offset for an estimated pitch period |lag| by
|
||||
// looking at the auto-correlation coefficients in the neighborhood of |lag|
|
||||
// (namely, |prev_auto_corr|, |lag_auto_corr| and |next_auto_corr|). The output
|
||||
// is an offset in {-1, 0, +1}.
|
||||
// TODO(bugs.webrtc.org/9076): Consider removing pseudo-interpolation since it
|
||||
// is relevant only if the spectral analysis works at a sample rate that is
|
||||
// twice that of the pitch buffer (not so important for the estimated
|
||||
// pitch period feature fed into the RNN).
|
||||
int GetPitchPseudoInterpolationOffset(size_t lag,
|
||||
float prev_auto_corr,
|
||||
float lag_auto_corr,
|
||||
float next_auto_corr) {
|
||||
const float& a = prev_auto_corr;
|
||||
const float& b = lag_auto_corr;
|
||||
const float& c = next_auto_corr;
|
||||
|
||||
int offset = 0;
|
||||
if ((c - a) > 0.7f * (b - a)) {
|
||||
offset = 1; // |c| is the largest auto-correlation coefficient.
|
||||
} else if ((a - c) > 0.7f * (b - c)) {
|
||||
offset = -1; // |a| is the largest auto-correlation coefficient.
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
|
||||
// Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The
|
||||
// output sample rate is twice that of |lag|.
|
||||
size_t PitchPseudoInterpolationLagPitchBuf(
|
||||
size_t lag,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
|
||||
int offset = 0;
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
if (lag > 0 && lag < kMaxPitch24kHz) {
|
||||
offset = GetPitchPseudoInterpolationOffset(
|
||||
lag,
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1),
|
||||
kMaxPitch24kHz),
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag),
|
||||
kMaxPitch24kHz),
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1),
|
||||
kMaxPitch24kHz));
|
||||
}
|
||||
return 2 * lag + offset;
|
||||
}
|
||||
|
||||
// Refines a pitch period |inv_lag| encoded as inverted lag with
|
||||
// pseudo-interpolation. The output sample rate is twice that of
|
||||
// |inv_lag|.
|
||||
size_t PitchPseudoInterpolationInvLagAutoCorr(
|
||||
size_t inv_lag,
|
||||
rtc::ArrayView<const float> auto_corr) {
|
||||
int offset = 0;
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
if (inv_lag > 0 && inv_lag < auto_corr.size() - 1) {
|
||||
offset = GetPitchPseudoInterpolationOffset(inv_lag, auto_corr[inv_lag + 1],
|
||||
auto_corr[inv_lag],
|
||||
auto_corr[inv_lag - 1]);
|
||||
}
|
||||
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
|
||||
// be subtracted since |inv_lag| is an inverted lag but offset is a lag.
|
||||
return 2 * inv_lag + offset;
|
||||
}
|
||||
|
||||
// Integer multipliers used in CheckLowerPitchPeriodsAndComputePitchGain() when
|
||||
// looking for sub-harmonics.
|
||||
// The values have been chosen to serve the following algorithm. Given the
|
||||
// initial pitch period T, we examine whether one of its harmonics is the true
|
||||
// fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of
|
||||
// these harmonics, in addition to the pitch gain of itself, we choose one
|
||||
// multiple of its pitch period, n*T/k, to validate it (by averaging their pitch
|
||||
// gains). The multiplier n is chosen so that n*T/k is used only one time over
|
||||
// all k. For example, when k = 4 we should also expect a peak at 3*T/4. When
|
||||
// k = 8, instead, we don't want to look at 2*T/8, since we have already checked
|
||||
// T/4 before. Instead, we look at T*3/8.
|
||||
// The array can be generated in Python as follows:
|
||||
// from fractions import Fraction
|
||||
// # Smallest positive integer not in X.
|
||||
// def mex(X):
|
||||
// for i in range(1, int(max(X)+2)):
|
||||
// if i not in X:
|
||||
// return i
|
||||
// # Visited multiples of the period.
|
||||
// S = {1}
|
||||
// for n in range(2, 16):
|
||||
// sn = mex({n * i for i in S} | {1})
|
||||
// S = S | {Fraction(1, n), Fraction(sn, n)}
|
||||
// print(sn, end=', ')
|
||||
constexpr std::array<int, 14> kSubHarmonicMultipliers = {
|
||||
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
|
||||
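// Editorial worked example (not an upstream comment): reading
// kSubHarmonicMultipliers[k - 2] for a few values of k gives
//   k = 2 -> 3, so T/2 is validated together with 3*T/2;
//   k = 4 -> 3, so T/4 is validated together with 3*T/4;
//   k = 8 -> 3, so T/8 is validated together with 3*T/8, since 2*T/8 = T/4
//   was already visited at k = 4 (the "smallest positive integer not in X"
//   rule in the Python snippet above).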
|
||||
// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
|
||||
// a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)].
|
||||
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
|
||||
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
|
||||
|
||||
} // namespace
|
||||
|
||||
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
||||
rtc::ArrayView<float, kBufSize12kHz> dst) {
|
||||
// TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter.
|
||||
static_assert(2 * dst.size() == src.size(), "");
|
||||
for (size_t i = 0; i < dst.size(); ++i) {
|
||||
dst[i] = src[2 * i];
|
||||
}
|
||||
}
|
||||
|
||||
float ComputePitchGainThreshold(int candidate_pitch_period,
|
||||
int pitch_period_ratio,
|
||||
int initial_pitch_period,
|
||||
float initial_pitch_gain,
|
||||
int prev_pitch_period,
|
||||
float prev_pitch_gain) {
|
||||
// Map arguments to more compact aliases.
|
||||
const int& t1 = candidate_pitch_period;
|
||||
const int& k = pitch_period_ratio;
|
||||
const int& t0 = initial_pitch_period;
|
||||
const float& g0 = initial_pitch_gain;
|
||||
const int& t_prev = prev_pitch_period;
|
||||
const float& g_prev = prev_pitch_gain;
|
||||
|
||||
// Validate input.
|
||||
RTC_DCHECK_GE(t1, 0);
|
||||
RTC_DCHECK_GE(k, 2);
|
||||
RTC_DCHECK_GE(t0, 0);
|
||||
RTC_DCHECK_GE(t_prev, 0);
|
||||
|
||||
// Compute a term that lowers the threshold when |t1| is close to the last
|
||||
// estimated period |t_prev| - i.e., pitch tracking.
|
||||
float lower_threshold_term = 0;
|
||||
if (abs(t1 - t_prev) <= 1) {
|
||||
// The candidate pitch period is within 1 sample from the previous one.
|
||||
// Make the candidate at |t1| very easy to accept.
|
||||
lower_threshold_term = g_prev;
|
||||
} else if (abs(t1 - t_prev) == 2 &&
|
||||
t0 > kInitialPitchPeriodThresholds[k - 2]) {
|
||||
// The candidate pitch period is 2 samples away from the previous one and the
|
||||
// period |t0| (from which |t1| has been derived) is greater than a
|
||||
// threshold. Make |t1| easy to accept.
|
||||
lower_threshold_term = 0.5f * g_prev;
|
||||
}
|
||||
// Set the threshold based on the gain of the initial estimate |t0|. Also
|
||||
// reduce the chance of false positives caused by a bias towards high
|
||||
// frequencies (originating from short-term correlations).
|
||||
float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
|
||||
if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
|
||||
// High frequency.
|
||||
threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
|
||||
} else if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
|
||||
// Even higher frequency.
|
||||
threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
|
||||
}
|
||||
return threshold;
|
||||
}
|
||||
|
||||
void ComputeSlidingFrameSquareEnergies(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values) {
|
||||
float yy =
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz);
|
||||
yy_values[0] = yy;
|
||||
for (size_t i = 1; i < yy_values.size(); ++i) {
|
||||
RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz);
|
||||
RTC_DCHECK_LE(i, kMaxPitch24kHz);
|
||||
const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i];
|
||||
const float new_coeff = pitch_buf[kMaxPitch24kHz - i];
|
||||
yy -= old_coeff * old_coeff;
|
||||
yy += new_coeff * new_coeff;
|
||||
yy = std::max(0.f, yy);
|
||||
yy_values[i] = yy;
|
||||
}
|
||||
}
|
||||
|
||||
std::array<size_t, 2> FindBestPitchPeriods(
|
||||
rtc::ArrayView<const float> auto_corr,
|
||||
rtc::ArrayView<const float> pitch_buf,
|
||||
size_t max_pitch_period) {
|
||||
// Stores a pitch candidate period and strength information.
|
||||
struct PitchCandidate {
|
||||
// Pitch period encoded as inverted lag.
|
||||
size_t period_inverted_lag = 0;
|
||||
// Pitch strength encoded as a ratio.
|
||||
float strength_numerator = -1.f;
|
||||
float strength_denominator = 0.f;
|
||||
// Compares the strength of two pitch candidates.
|
||||
bool HasStrongerPitchThan(const PitchCandidate& b) const {
|
||||
// Comparing the numerator/denominator ratios without using divisions.
|
||||
return strength_numerator * b.strength_denominator >
|
||||
b.strength_numerator * strength_denominator;
|
||||
}
|
||||
};
|
||||
|
||||
RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
|
||||
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
|
||||
const size_t frame_size = pitch_buf.size() - max_pitch_period;
|
||||
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
|
||||
float yy =
|
||||
std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1,
|
||||
pitch_buf.begin(), 1.f);
|
||||
// Search best and second best pitches by looking at the scaled
|
||||
// auto-correlation.
|
||||
PitchCandidate candidate;
|
||||
PitchCandidate best;
|
||||
PitchCandidate second_best;
|
||||
second_best.period_inverted_lag = 1;
|
||||
for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
|
||||
// A pitch candidate must have positive correlation.
|
||||
if (auto_corr[inv_lag] > 0) {
|
||||
candidate.period_inverted_lag = inv_lag;
|
||||
candidate.strength_numerator = auto_corr[inv_lag] * auto_corr[inv_lag];
|
||||
candidate.strength_denominator = yy;
|
||||
if (candidate.HasStrongerPitchThan(second_best)) {
|
||||
if (candidate.HasStrongerPitchThan(best)) {
|
||||
second_best = best;
|
||||
best = candidate;
|
||||
} else {
|
||||
second_best = candidate;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Update |yy| (the sliding frame squared energy) for the next inverted lag.
|
||||
const float old_coeff = pitch_buf[inv_lag];
|
||||
const float new_coeff = pitch_buf[inv_lag + frame_size];
|
||||
yy -= old_coeff * old_coeff;
|
||||
yy += new_coeff * new_coeff;
|
||||
yy = std::max(0.f, yy);
|
||||
}
|
||||
return {{best.period_inverted_lag, second_best.period_inverted_lag}};
|
||||
}
|
||||
|
||||
size_t RefinePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
rtc::ArrayView<const size_t, 2> inv_lags) {
|
||||
// Compute the auto-correlation terms only for neighbors of the given pitch
|
||||
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
|
||||
// for a few lag values).
|
||||
std::array<float, kNumInvertedLags24kHz> auto_corr;
|
||||
auto_corr.fill(0.f); // Zeros become ignored lags in FindBestPitchPeriods().
|
||||
auto is_neighbor = [](size_t i, size_t j) {
|
||||
return ((i > j) ? (i - j) : (j - i)) <= 2;
|
||||
};
|
||||
for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
|
||||
if (is_neighbor(inv_lag, inv_lags[0]) || is_neighbor(inv_lag, inv_lags[1]))
|
||||
auto_corr[inv_lag] =
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, kMaxPitch24kHz);
|
||||
}
|
||||
// Find best pitch at 24 kHz.
|
||||
const auto pitch_candidates_inv_lags = FindBestPitchPeriods(
|
||||
{auto_corr.data(), auto_corr.size()},
|
||||
{pitch_buf.data(), pitch_buf.size()}, kMaxPitch24kHz);
|
||||
const auto inv_lag = pitch_candidates_inv_lags[0]; // Refine the best.
|
||||
// Pseudo-interpolation.
|
||||
return PitchPseudoInterpolationInvLagAutoCorr(inv_lag, auto_corr);
|
||||
}
|
||||
|
||||
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
int initial_pitch_period_48kHz,
|
||||
PitchInfo prev_pitch_48kHz) {
|
||||
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
|
||||
RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
|
||||
// Stores information for a refined pitch candidate.
|
||||
struct RefinedPitchCandidate {
|
||||
RefinedPitchCandidate() {}
|
||||
RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy)
|
||||
: period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
|
||||
int period_24kHz;
|
||||
// Pitch strength information.
|
||||
float gain;
|
||||
// Additional pitch strength information used for the final estimation of
|
||||
// pitch gain.
|
||||
float xy; // Cross-correlation.
|
||||
float yy; // Auto-correlation.
|
||||
};
|
||||
|
||||
// Initialize.
|
||||
std::array<float, kMaxPitch24kHz + 1> yy_values;
|
||||
ComputeSlidingFrameSquareEnergies(pitch_buf,
|
||||
{yy_values.data(), yy_values.size()});
|
||||
const float xx = yy_values[0];
|
||||
// Helper lambdas.
|
||||
const auto pitch_gain = [](float xy, float yy, float xx) {
|
||||
RTC_DCHECK_LE(0.f, xx * yy);
|
||||
return xy / std::sqrt(1.f + xx * yy);
|
||||
};
|
||||
// Initial pitch candidate gain.
|
||||
RefinedPitchCandidate best_pitch;
|
||||
best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2,
|
||||
static_cast<int>(kMaxPitch24kHz - 1));
|
||||
best_pitch.xy = ComputeAutoCorrelationCoeff(
|
||||
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
|
||||
best_pitch.yy = yy_values[best_pitch.period_24kHz];
|
||||
best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx);
|
||||
|
||||
// Store the initial pitch period information.
|
||||
const size_t initial_pitch_period = best_pitch.period_24kHz;
|
||||
const float initial_pitch_gain = best_pitch.gain;
|
||||
|
||||
// Given the initial pitch estimation, check lower periods (i.e., harmonics).
|
||||
const auto alternative_period = [](int period, int k, int n) -> int {
|
||||
RTC_DCHECK_GT(k, 0);
|
||||
return (2 * n * period + k) / (2 * k); // Same as round(n*period/k).
|
||||
};
|
||||
for (int k = 2; k < static_cast<int>(kSubHarmonicMultipliers.size() + 2);
|
||||
++k) {
|
||||
int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
|
||||
if (static_cast<size_t>(candidate_pitch_period) < kMinPitch24kHz) {
|
||||
break;
|
||||
}
|
||||
// When looking at |candidate_pitch_period|, we also look at one of its
|
||||
// sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
|
||||
// |k| == 2 is a special case since |candidate_pitch_secondary_period| might
|
||||
// be greater than the maximum pitch period.
|
||||
int candidate_pitch_secondary_period = alternative_period(
|
||||
initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
|
||||
RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
|
||||
if (k == 2 &&
|
||||
candidate_pitch_secondary_period > static_cast<int>(kMaxPitch24kHz)) {
|
||||
candidate_pitch_secondary_period = initial_pitch_period;
|
||||
}
|
||||
RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
|
||||
<< "The lower pitch period and the additional sub-harmonic must not "
|
||||
"coincide.";
|
||||
// Compute an auto-correlation score for the primary pitch candidate
|
||||
// |candidate_pitch_period| by also looking at its possible sub-harmonic
|
||||
// |candidate_pitch_secondary_period|.
|
||||
float xy_primary_period = ComputeAutoCorrelationCoeff(
|
||||
pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz);
|
||||
float xy_secondary_period = ComputeAutoCorrelationCoeff(
|
||||
pitch_buf, GetInvertedLag(candidate_pitch_secondary_period),
|
||||
kMaxPitch24kHz);
|
||||
float xy = 0.5f * (xy_primary_period + xy_secondary_period);
|
||||
float yy = 0.5f * (yy_values[candidate_pitch_period] +
|
||||
yy_values[candidate_pitch_secondary_period]);
|
||||
float candidate_pitch_gain = pitch_gain(xy, yy, xx);
|
||||
|
||||
// Maybe update best period.
|
||||
float threshold = ComputePitchGainThreshold(
|
||||
candidate_pitch_period, k, initial_pitch_period, initial_pitch_gain,
|
||||
prev_pitch_48kHz.period / 2, prev_pitch_48kHz.gain);
|
||||
if (candidate_pitch_gain > threshold) {
|
||||
best_pitch = {candidate_pitch_period, candidate_pitch_gain, xy, yy};
|
||||
}
|
||||
}
|
||||
|
||||
// Final pitch gain and period.
|
||||
best_pitch.xy = std::max(0.f, best_pitch.xy);
|
||||
RTC_DCHECK_LE(0.f, best_pitch.yy);
|
||||
float final_pitch_gain = (best_pitch.yy <= best_pitch.xy)
|
||||
? 1.f
|
||||
: best_pitch.xy / (best_pitch.yy + 1.f);
|
||||
final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
|
||||
int final_pitch_period_48kHz = std::max(
|
||||
kMinPitch48kHz,
|
||||
PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
|
||||
|
||||
return {final_pitch_period_48kHz, final_pitch_gain};
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Performs 2x decimation without any anti-aliasing filter.
|
||||
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
||||
rtc::ArrayView<float, kBufSize12kHz> dst);
|
||||
|
||||
// Computes a gain threshold for a candidate pitch period given the initial and
|
||||
// the previous pitch period and gain estimates and the pitch period ratio used
|
||||
// to derive the candidate pitch period from the initial period.
|
||||
float ComputePitchGainThreshold(int candidate_pitch_period,
|
||||
int pitch_period_ratio,
|
||||
int initial_pitch_period,
|
||||
float initial_pitch_gain,
|
||||
int prev_pitch_period,
|
||||
float prev_pitch_gain);
|
||||
|
||||
// Computes the sum of squared samples for every sliding frame in the pitch
|
||||
// buffer. |yy_values| indexes are lags.
|
||||
//
|
||||
// The pitch buffer is structured as depicted below:
|
||||
// |.........|...........|
|
||||
// a b
|
||||
// The part on the left, named "a", contains the oldest samples, whereas "b" the
|
||||
// most recent ones. The size of "a" corresponds to the maximum pitch period,
|
||||
// that of "b" to the frame size (e.g., 16 ms and 20 ms respectively).
|
||||
void ComputeSlidingFrameSquareEnergies(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values);
|
||||
|
||||
// Given the auto-correlation coefficients stored according to
|
||||
// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best
|
||||
// and the second best pitch periods.
|
||||
std::array<size_t, 2> FindBestPitchPeriods(
|
||||
rtc::ArrayView<const float> auto_corr,
|
||||
rtc::ArrayView<const float> pitch_buf,
|
||||
size_t max_pitch_period);
|
||||
|
||||
// Refines the pitch period estimation given the pitch buffer |pitch_buf| and
|
||||
// the initial pitch period estimation |inv_lags|. Returns an inverted lag at
|
||||
// 48 kHz.
|
||||
size_t RefinePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
rtc::ArrayView<const size_t, 2> inv_lags);
|
||||
|
||||
// Refines the pitch period estimation and computes the pitch gain. Returns the
|
||||
// refined pitch estimation data at 48 kHz.
|
||||
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
int initial_pitch_period_48kHz,
|
||||
PitchInfo prev_pitch_48kHz);
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_
|
66
webrtc/modules/audio_processing/agc2/rnn_vad/ring_buffer.h
Normal file
66
webrtc/modules/audio_processing/agc2/rnn_vad/ring_buffer.h
Normal file
@ -0,0 +1,66 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
|
||||
#include "api/array_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Ring buffer for N arrays of type T each one with size S.
|
||||
template <typename T, size_t S, size_t N>
|
||||
class RingBuffer {
|
||||
static_assert(S > 0, "");
|
||||
static_assert(N > 0, "");
|
||||
static_assert(std::is_arithmetic<T>::value,
|
||||
"Integral or floating point required.");
|
||||
|
||||
public:
|
||||
RingBuffer() : tail_(0) {}
|
||||
RingBuffer(const RingBuffer&) = delete;
|
||||
RingBuffer& operator=(const RingBuffer&) = delete;
|
||||
~RingBuffer() = default;
|
||||
// Set the ring buffer values to zero.
|
||||
void Reset() { buffer_.fill(0); }
|
||||
// Replace the least recently pushed array in the buffer with |new_values|.
|
||||
void Push(rtc::ArrayView<const T, S> new_values) {
|
||||
std::memcpy(buffer_.data() + S * tail_, new_values.data(), S * sizeof(T));
|
||||
tail_ += 1;
|
||||
if (tail_ == N)
|
||||
tail_ = 0;
|
||||
}
|
||||
// Return an array view onto the array with a given delay. A view on the last
|
||||
// and least recently pushed array is returned when |delay| is 0 and N - 1
|
||||
// respectively.
|
||||
rtc::ArrayView<const T, S> GetArrayView(size_t delay) const {
|
||||
const int delay_int = static_cast<int>(delay);
|
||||
RTC_DCHECK_LE(0, delay_int);
|
||||
RTC_DCHECK_LT(delay_int, N);
|
||||
int offset = tail_ - 1 - delay_int;
|
||||
if (offset < 0)
|
||||
offset += N;
|
||||
return {buffer_.data() + S * offset, S};
|
||||
}
|
||||
|
||||
private:
|
||||
int tail_; // Index of the least recently pushed sub-array.
|
||||
std::array<T, S * N> buffer_{};
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
|
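Editor's note: a minimal usage sketch for the ring buffer above (not part of the upstream tree), storing the last 3 arrays of 4 floats each.

#include <array>

#include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h"

void RingBufferUsageSketch() {
  webrtc::rnn_vad::RingBuffer<float, 4, 3> ring;
  std::array<float, 4> chunk = {1.f, 2.f, 3.f, 4.f};
  ring.Push(chunk);                    // |chunk| becomes the most recent array.
  auto newest = ring.GetArrayView(0);  // View on the values just pushed.
  auto oldest = ring.GetArrayView(2);  // View on the least recently pushed array.
  (void)newest;
  (void)oldest;
}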
425
webrtc/modules/audio_processing/agc2/rnn_vad/rnn.cc
Normal file
425
webrtc/modules/audio_processing/agc2/rnn_vad/rnn.cc
Normal file
@ -0,0 +1,425 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||
|
||||
// Defines WEBRTC_ARCH_X86_FAMILY, used below.
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
#include <emmintrin.h>
|
||||
#endif
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
using rnnoise::kWeightsScale;
|
||||
|
||||
using rnnoise::kInputLayerInputSize;
|
||||
static_assert(kFeatureVectorSize == kInputLayerInputSize, "");
|
||||
using rnnoise::kInputDenseBias;
|
||||
using rnnoise::kInputDenseWeights;
|
||||
using rnnoise::kInputLayerOutputSize;
|
||||
static_assert(kInputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
|
||||
"Increase kFullyConnectedLayersMaxUnits.");
|
||||
|
||||
using rnnoise::kHiddenGruBias;
|
||||
using rnnoise::kHiddenGruRecurrentWeights;
|
||||
using rnnoise::kHiddenGruWeights;
|
||||
using rnnoise::kHiddenLayerOutputSize;
|
||||
static_assert(kHiddenLayerOutputSize <= kRecurrentLayersMaxUnits,
|
||||
"Increase kRecurrentLayersMaxUnits.");
|
||||
|
||||
using rnnoise::kOutputDenseBias;
|
||||
using rnnoise::kOutputDenseWeights;
|
||||
using rnnoise::kOutputLayerOutputSize;
|
||||
static_assert(kOutputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
|
||||
"Increase kFullyConnectedLayersMaxUnits.");
|
||||
|
||||
using rnnoise::SigmoidApproximated;
|
||||
using rnnoise::TansigApproximated;
|
||||
|
||||
inline float RectifiedLinearUnit(float x) {
|
||||
return x < 0.f ? 0.f : x;
|
||||
}
|
||||
|
||||
std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
|
||||
std::vector<float> scaled_params(params.size());
|
||||
std::transform(params.begin(), params.end(), scaled_params.begin(),
|
||||
[](int8_t x) -> float {
|
||||
return rnnoise::kWeightsScale * static_cast<float>(x);
|
||||
});
|
||||
return scaled_params;
|
||||
}
|
||||
|
||||
// TODO(bugs.chromium.org/10480): Hard-code optimized layout and remove this
|
||||
// function to improve setup time.
|
||||
// Casts and scales |weights| and re-arranges the layout.
|
||||
std::vector<float> GetPreprocessedFcWeights(
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
size_t output_size) {
|
||||
if (output_size == 1) {
|
||||
return GetScaledParams(weights);
|
||||
}
|
||||
// Transpose, scale and cast.
|
||||
const size_t input_size = rtc::CheckedDivExact(weights.size(), output_size);
|
||||
std::vector<float> w(weights.size());
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
w[o * input_size + i] = rnnoise::kWeightsScale *
|
||||
static_cast<float>(weights[i * output_size + o]);
|
||||
}
|
||||
}
|
||||
return w;
|
||||
}
|
||||
|
||||
constexpr size_t kNumGruGates = 3; // Update, reset, output.
|
||||
|
||||
// TODO(bugs.chromium.org/10480): Hard-code the optimized layout and remove this
|
||||
// function to improve setup time.
|
||||
// Casts and scales |tensor_src| for a GRU layer and re-arranges the layout.
|
||||
// It works for weights, recurrent weights, and bias.
|
||||
std::vector<float> GetPreprocessedGruTensor(
|
||||
rtc::ArrayView<const int8_t> tensor_src,
|
||||
size_t output_size) {
|
||||
// Transpose, cast and scale.
|
||||
// |n| is the size of the first dimension of the 3-dim tensor |weights|.
|
||||
const size_t n =
|
||||
rtc::CheckedDivExact(tensor_src.size(), output_size * kNumGruGates);
|
||||
const size_t stride_src = kNumGruGates * output_size;
|
||||
const size_t stride_dst = n * output_size;
|
||||
std::vector<float> tensor_dst(tensor_src.size());
|
||||
for (size_t g = 0; g < kNumGruGates; ++g) {
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
tensor_dst[g * stride_dst + o * n + i] =
|
||||
rnnoise::kWeightsScale *
|
||||
static_cast<float>(
|
||||
tensor_src[i * stride_src + g * output_size + o]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return tensor_dst;
|
||||
}
|
||||
|
||||
void ComputeGruUpdateResetGates(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> state,
|
||||
rtc::ArrayView<float> gate) {
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
gate[o] = bias[o];
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
gate[o] += input[i] * weights[o * input_size + i];
|
||||
}
|
||||
for (size_t s = 0; s < output_size; ++s) {
|
||||
gate[o] += state[s] * recurrent_weights[o * output_size + s];
|
||||
}
|
||||
gate[o] = SigmoidApproximated(gate[o]);
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeGruOutputGate(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> state,
|
||||
rtc::ArrayView<const float> reset,
|
||||
rtc::ArrayView<float> gate) {
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
gate[o] = bias[o];
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
gate[o] += input[i] * weights[o * input_size + i];
|
||||
}
|
||||
for (size_t s = 0; s < output_size; ++s) {
|
||||
gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
|
||||
}
|
||||
gate[o] = RectifiedLinearUnit(gate[o]);
|
||||
}
|
||||
}
|
||||
|
||||
// Gated recurrent unit (GRU) layer un-optimized implementation.
|
||||
void ComputeGruLayerOutput(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<float> state) {
|
||||
RTC_DCHECK_EQ(input_size, input.size());
|
||||
// Stride and offset used to read parameter arrays.
|
||||
const size_t stride_in = input_size * output_size;
|
||||
const size_t stride_out = output_size * output_size;
|
||||
|
||||
// Update gate.
|
||||
std::array<float, kRecurrentLayersMaxUnits> update;
|
||||
ComputeGruUpdateResetGates(
|
||||
input_size, output_size, weights.subview(0, stride_in),
|
||||
recurrent_weights.subview(0, stride_out), bias.subview(0, output_size),
|
||||
input, state, update);
|
||||
|
||||
// Reset gate.
|
||||
std::array<float, kRecurrentLayersMaxUnits> reset;
|
||||
ComputeGruUpdateResetGates(
|
||||
input_size, output_size, weights.subview(stride_in, stride_in),
|
||||
recurrent_weights.subview(stride_out, stride_out),
|
||||
bias.subview(output_size, output_size), input, state, reset);
|
||||
|
||||
// Output gate.
|
||||
std::array<float, kRecurrentLayersMaxUnits> output;
|
||||
ComputeGruOutputGate(
|
||||
input_size, output_size, weights.subview(2 * stride_in, stride_in),
|
||||
recurrent_weights.subview(2 * stride_out, stride_out),
|
||||
bias.subview(2 * output_size, output_size), input, state, reset, output);
|
||||
|
||||
// Update output through the update gates and update the state.
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
output[o] = update[o] * state[o] + (1.f - update[o]) * output[o];
|
||||
state[o] = output[o];
|
||||
}
|
||||
}
|
||||
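// Editorial summary of ComputeGruLayerOutput() above (not an upstream
// comment). With x = input, h = state, sigma = SigmoidApproximated and "*"
// denoting element-wise products:
//   u  = sigma(W_u x + R_u h + b_u)         (update gate)
//   r  = sigma(W_r x + R_r h + b_r)         (reset gate)
//   h' = ReLU(W_o x + R_o (r * h) + b_o)    (output gate)
//   h  = u * h + (1 - u) * h'               (new state)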
|
||||
// Fully connected layer un-optimized implementation.
|
||||
void ComputeFullyConnectedLayerOutput(
|
||||
size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
rtc::ArrayView<float> output) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
output[o] = bias[o];
|
||||
// TODO(bugs.chromium.org/9076): Benchmark how different layouts for
|
||||
// |weights_| change the performance across different platforms.
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
output[o] += input[i] * weights[o * input_size + i];
|
||||
}
|
||||
output[o] = activation_function(output[o]);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
// Fully connected layer SSE2 implementation.
|
||||
void ComputeFullyConnectedLayerOutputSse2(
|
||||
size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
rtc::ArrayView<float> output) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
const size_t input_size_by_4 = input_size >> 2;
|
||||
const size_t offset = input_size & ~3;
|
||||
__m128 sum_wx_128;
|
||||
const float* v = reinterpret_cast<const float*>(&sum_wx_128);
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
// Perform 128 bit vector operations.
|
||||
sum_wx_128 = _mm_set1_ps(0);
|
||||
const float* x_p = input.data();
|
||||
const float* w_p = weights.data() + o * input_size;
|
||||
for (size_t i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) {
|
||||
sum_wx_128 = _mm_add_ps(sum_wx_128,
|
||||
_mm_mul_ps(_mm_loadu_ps(x_p), _mm_loadu_ps(w_p)));
|
||||
}
|
||||
// Perform non-vector operations for any remaining items, sum up bias term
|
||||
// and results from the vectorized code, and apply the activation function.
|
||||
output[o] = activation_function(
|
||||
std::inner_product(input.begin() + offset, input.end(),
|
||||
weights.begin() + o * input_size + offset,
|
||||
bias[o] + v[0] + v[1] + v[2] + v[3]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
|
||||
FullyConnectedLayer::FullyConnectedLayer(
|
||||
const size_t input_size,
|
||||
const size_t output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
Optimization optimization)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(GetScaledParams(bias)),
|
||||
weights_(GetPreprocessedFcWeights(weights, output_size)),
|
||||
activation_function_(activation_function),
|
||||
optimization_(optimization) {
|
||||
RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
|
||||
<< "Static over-allocation of fully-connected layers output vectors is "
|
||||
"not sufficient.";
|
||||
RTC_DCHECK_EQ(output_size_, bias_.size())
|
||||
<< "Mismatching output size and bias terms array size.";
|
||||
RTC_DCHECK_EQ(input_size_ * output_size_, weights_.size())
|
||||
<< "Mismatching input-output size and weight coefficients array size.";
|
||||
}
|
||||
|
||||
FullyConnectedLayer::~FullyConnectedLayer() = default;
|
||||
|
||||
rtc::ArrayView<const float> FullyConnectedLayer::GetOutput() const {
|
||||
return rtc::ArrayView<const float>(output_.data(), output_size_);
|
||||
}
|
||||
|
||||
void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
switch (optimization_) {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
case Optimization::kSse2:
|
||||
ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input,
|
||||
bias_, weights_,
|
||||
activation_function_, output_);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Optimization::kNeon:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
|
||||
ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
|
||||
weights_, activation_function_, output_);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
|
||||
weights_, activation_function_, output_);
|
||||
}
|
||||
}
|
||||
|
||||
GatedRecurrentLayer::GatedRecurrentLayer(
|
||||
const size_t input_size,
|
||||
const size_t output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
const rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
Optimization optimization)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(GetPreprocessedGruTensor(bias, output_size)),
|
||||
weights_(GetPreprocessedGruTensor(weights, output_size)),
|
||||
recurrent_weights_(
|
||||
GetPreprocessedGruTensor(recurrent_weights, output_size)),
|
||||
optimization_(optimization) {
|
||||
RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
|
||||
<< "Static over-allocation of recurrent layers state vectors is not "
|
||||
"sufficient.";
|
||||
RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
|
||||
<< "Mismatching output size and bias terms array size.";
|
||||
RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
|
||||
<< "Mismatching input-output size and weight coefficients array size.";
|
||||
RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
|
||||
recurrent_weights_.size())
|
||||
<< "Mismatching input-output size and recurrent weight coefficients array"
|
||||
" size.";
|
||||
Reset();
|
||||
}
|
||||
|
||||
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
|
||||
|
||||
rtc::ArrayView<const float> GatedRecurrentLayer::GetOutput() const {
|
||||
return rtc::ArrayView<const float>(state_.data(), output_size_);
|
||||
}
|
||||
|
||||
void GatedRecurrentLayer::Reset() {
|
||||
state_.fill(0.f);
|
||||
}
|
||||
|
||||
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
switch (optimization_) {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
case Optimization::kSse2:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
|
||||
ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
|
||||
recurrent_weights_, bias_, state_);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Optimization::kNeon:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
|
||||
ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
|
||||
recurrent_weights_, bias_, state_);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
|
||||
recurrent_weights_, bias_, state_);
|
||||
}
|
||||
}
|
||||
|
||||
RnnBasedVad::RnnBasedVad()
|
||||
: input_layer_(kInputLayerInputSize,
|
||||
kInputLayerOutputSize,
|
||||
kInputDenseBias,
|
||||
kInputDenseWeights,
|
||||
TansigApproximated,
|
||||
DetectOptimization()),
|
||||
hidden_layer_(kInputLayerOutputSize,
|
||||
kHiddenLayerOutputSize,
|
||||
kHiddenGruBias,
|
||||
kHiddenGruWeights,
|
||||
kHiddenGruRecurrentWeights,
|
||||
DetectOptimization()),
|
||||
output_layer_(kHiddenLayerOutputSize,
|
||||
kOutputLayerOutputSize,
|
||||
kOutputDenseBias,
|
||||
kOutputDenseWeights,
|
||||
SigmoidApproximated,
|
||||
DetectOptimization()) {
|
||||
// Input-output chaining size checks.
|
||||
RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size())
|
||||
<< "The input and the hidden layers sizes do not match.";
|
||||
RTC_DCHECK_EQ(hidden_layer_.output_size(), output_layer_.input_size())
|
||||
<< "The hidden and the output layers sizes do not match.";
|
||||
}
|
||||
|
||||
RnnBasedVad::~RnnBasedVad() = default;
|
||||
|
||||
void RnnBasedVad::Reset() {
|
||||
hidden_layer_.Reset();
|
||||
}
|
||||
|
||||
float RnnBasedVad::ComputeVadProbability(
|
||||
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
|
||||
bool is_silence) {
|
||||
if (is_silence) {
|
||||
Reset();
|
||||
return 0.f;
|
||||
}
|
||||
input_layer_.ComputeOutput(feature_vector);
|
||||
hidden_layer_.ComputeOutput(input_layer_.GetOutput());
|
||||
output_layer_.ComputeOutput(hidden_layer_.GetOutput());
|
||||
const auto vad_output = output_layer_.GetOutput();
|
||||
return vad_output[0];
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
126
webrtc/modules/audio_processing/agc2/rnn_vad/rnn.h
Normal file
126
webrtc/modules/audio_processing/agc2/rnn_vad/rnn.h
Normal file
@ -0,0 +1,126 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "api/function_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Maximum number of units for a fully-connected layer. This value is used to
|
||||
// over-allocate space for fully-connected layers output vectors (implemented as
|
||||
// std::array). The value should equal the number of units of the largest
|
||||
// fully-connected layer.
|
||||
constexpr size_t kFullyConnectedLayersMaxUnits = 24;
|
||||
|
||||
// Maximum number of units for a recurrent layer. This value is used to
|
||||
// over-allocate space for recurrent layers state vectors (implemented as
|
||||
// std::array). The value should equal the number of units of the largest
|
||||
// recurrent layer.
|
||||
constexpr size_t kRecurrentLayersMaxUnits = 24;
|
||||
|
||||
// Fully-connected layer.
|
||||
class FullyConnectedLayer {
|
||||
public:
|
||||
FullyConnectedLayer(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
Optimization optimization);
|
||||
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
|
||||
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
|
||||
~FullyConnectedLayer();
|
||||
size_t input_size() const { return input_size_; }
|
||||
size_t output_size() const { return output_size_; }
|
||||
Optimization optimization() const { return optimization_; }
|
||||
rtc::ArrayView<const float> GetOutput() const;
|
||||
// Computes the fully-connected layer output.
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
const size_t input_size_;
|
||||
const size_t output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
rtc::FunctionView<float(float)> activation_function_;
|
||||
// The output vector of a fully-connected layer has length equal to |output_size_|.
|
||||
// However, for efficiency, over-allocation is used.
|
||||
std::array<float, kFullyConnectedLayersMaxUnits> output_;
|
||||
const Optimization optimization_;
|
||||
};
|
||||
|
||||
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
|
||||
// activation functions for the update/reset and output gates respectively.
|
||||
class GatedRecurrentLayer {
|
||||
public:
|
||||
GatedRecurrentLayer(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
Optimization optimization);
|
||||
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
|
||||
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
|
||||
~GatedRecurrentLayer();
|
||||
size_t input_size() const { return input_size_; }
|
||||
size_t output_size() const { return output_size_; }
|
||||
Optimization optimization() const { return optimization_; }
|
||||
rtc::ArrayView<const float> GetOutput() const;
|
||||
void Reset();
|
||||
// Computes the recurrent layer output and updates the state.
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
const size_t input_size_;
|
||||
const size_t output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
const std::vector<float> recurrent_weights_;
|
||||
// The state vector of a recurrent layer has length equal to |output_size_|.
|
||||
// However, to avoid dynamic allocation, over-allocation is used.
|
||||
std::array<float, kRecurrentLayersMaxUnits> state_;
|
||||
const Optimization optimization_;
|
||||
};
|
||||
|
||||
// Recurrent network based VAD.
|
||||
class RnnBasedVad {
|
||||
public:
|
||||
RnnBasedVad();
|
||||
RnnBasedVad(const RnnBasedVad&) = delete;
|
||||
RnnBasedVad& operator=(const RnnBasedVad&) = delete;
|
||||
~RnnBasedVad();
|
||||
void Reset();
|
||||
// Computes and returns the probability of voice (range: [0.0, 1.0]).
|
||||
float ComputeVadProbability(
|
||||
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
|
||||
bool is_silence);
|
||||
|
||||
private:
|
||||
FullyConnectedLayer input_layer_;
|
||||
GatedRecurrentLayer hidden_layer_;
|
||||
FullyConnectedLayer output_layer_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
|
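Editor's note: a minimal usage sketch for the VAD declared above (not part of the upstream tree). The zero-filled feature vector is a placeholder; in the real pipeline it is produced by FeaturesExtractor, as in rnn_vad_tool.cc below.

#include <array>

#include "modules/audio_processing/agc2/rnn_vad/rnn.h"

void RnnBasedVadUsageSketch() {
  using namespace webrtc::rnn_vad;
  RnnBasedVad vad;
  std::array<float, kFeatureVectorSize> features{};  // Placeholder features.
  const float p = vad.ComputeVadProbability(features, /*is_silence=*/false);
  // |p| lies in [0, 1]; calling with is_silence = true resets the hidden GRU
  // state and returns 0.
  (void)p;
}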
120
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc
Normal file
120
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc
Normal file
@ -0,0 +1,120 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "common_audio/resampler/push_sinc_resampler.h"
|
||||
#include "common_audio/wav_file.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||
#include "rtc_base/logging.h"
|
||||
|
||||
ABSL_FLAG(std::string, i, "", "Path to the input wav file");
|
||||
ABSL_FLAG(std::string, f, "", "Path to the output features file");
|
||||
ABSL_FLAG(std::string, o, "", "Path to the output VAD probabilities file");
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
rtc::LogMessage::LogToDebug(rtc::LS_INFO);
|
||||
|
||||
// Open wav input file and check properties.
|
||||
const std::string input_wav_file = absl::GetFlag(FLAGS_i);
|
||||
WavReader wav_reader(input_wav_file);
|
||||
if (wav_reader.num_channels() != 1) {
|
||||
RTC_LOG(LS_ERROR) << "Only mono wav files are supported";
|
||||
return 1;
|
||||
}
|
||||
if (wav_reader.sample_rate() % 100 != 0) {
|
||||
RTC_LOG(LS_ERROR) << "The sample rate rate must allow 10 ms frames.";
|
||||
return 1;
|
||||
}
|
||||
RTC_LOG(LS_INFO) << "Input sample rate: " << wav_reader.sample_rate();
|
||||
|
||||
// Init output files.
|
||||
const std::string output_vad_probs_file = absl::GetFlag(FLAGS_o);
|
||||
FILE* vad_probs_file = fopen(output_vad_probs_file.c_str(), "wb");
|
||||
FILE* features_file = nullptr;
|
||||
const std::string output_feature_file = absl::GetFlag(FLAGS_f);
|
||||
if (!output_feature_file.empty()) {
|
||||
features_file = fopen(output_feature_file.c_str(), "wb");
|
||||
}
|
||||
|
||||
// Initialize.
|
||||
const size_t frame_size_10ms =
|
||||
rtc::CheckedDivExact(wav_reader.sample_rate(), 100);
|
||||
std::vector<float> samples_10ms;
|
||||
samples_10ms.resize(frame_size_10ms);
|
||||
std::array<float, kFrameSize10ms24kHz> samples_10ms_24kHz;
|
||||
PushSincResampler resampler(frame_size_10ms, kFrameSize10ms24kHz);
|
||||
FeaturesExtractor features_extractor;
|
||||
std::array<float, kFeatureVectorSize> feature_vector;
|
||||
RnnBasedVad rnn_vad;
|
||||
|
||||
// Compute VAD probabilities.
|
||||
while (true) {
|
||||
// Read frame at the input sample rate.
|
||||
const auto read_samples =
|
||||
wav_reader.ReadSamples(frame_size_10ms, samples_10ms.data());
|
||||
if (read_samples < frame_size_10ms) {
|
||||
break; // EOF.
|
||||
}
|
||||
// Resample input.
|
||||
resampler.Resample(samples_10ms.data(), samples_10ms.size(),
|
||||
samples_10ms_24kHz.data(), samples_10ms_24kHz.size());
|
||||
|
||||
// Extract features and feed the RNN.
|
||||
bool is_silence = features_extractor.CheckSilenceComputeFeatures(
|
||||
samples_10ms_24kHz, feature_vector);
|
||||
float vad_probability =
|
||||
rnn_vad.ComputeVadProbability(feature_vector, is_silence);
|
||||
// Write voice probability.
|
||||
RTC_DCHECK_GE(vad_probability, 0.f);
|
||||
RTC_DCHECK_GE(1.f, vad_probability);
|
||||
fwrite(&vad_probability, sizeof(float), 1, vad_probs_file);
|
||||
// Write features.
|
||||
if (features_file) {
|
||||
const float float_is_silence = is_silence ? 1.f : 0.f;
|
||||
fwrite(&float_is_silence, sizeof(float), 1, features_file);
|
||||
if (is_silence) {
|
||||
// Do not write uninitialized values.
|
||||
feature_vector.fill(0.f);
|
||||
}
|
||||
fwrite(feature_vector.data(), sizeof(float), kFeatureVectorSize,
|
||||
features_file);
|
||||
}
|
||||
}
|
||||
|
||||
// Close output file(s).
|
||||
fclose(vad_probs_file);
|
||||
RTC_LOG(LS_INFO) << "VAD probabilities written to " << output_vad_probs_file;
|
||||
if (features_file) {
|
||||
fclose(features_file);
|
||||
RTC_LOG(LS_INFO) << "features written to " << output_feature_file;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return webrtc::rnn_vad::test::main(argc, argv);
|
||||
}
|
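Editor's note: assuming the target above builds to a binary named rnn_vad_tool (the binary name is not spelled out in this diff), an invocation looks like

  ./rnn_vad_tool -i speech.wav -o vad_probabilities.dat -f features.dat

where -f is optional. Both outputs are raw streams of 32-bit floats written with fwrite(): one VAD probability per 10 ms frame and, when features are requested, the silence flag followed by the feature vector for each frame.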
@ -0,0 +1,79 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Linear buffer implementation to (i) push fixed size chunks of sequential data
|
||||
// and (ii) view contiguous parts of the buffer. The buffer and the pushed
|
||||
// chunks have size S and N respectively. For instance, when S = 2N the first
|
||||
// half of the sequence buffer is replaced with its second half, and the new N
|
||||
// values are written at the end of the buffer.
|
||||
// The class also provides a view on the most recent M values, where 0 < M <= S
|
||||
// and by default M = N.
|
||||
template <typename T, size_t S, size_t N, size_t M = N>
|
||||
class SequenceBuffer {
|
||||
static_assert(N <= S,
|
||||
"The new chunk size cannot be larger than the sequence buffer "
|
||||
"size.");
|
||||
static_assert(std::is_arithmetic<T>::value,
|
||||
"Integral or floating point required.");
|
||||
|
||||
public:
|
||||
SequenceBuffer() : buffer_(S) {
|
||||
RTC_DCHECK_EQ(S, buffer_.size());
|
||||
Reset();
|
||||
}
|
||||
SequenceBuffer(const SequenceBuffer&) = delete;
|
||||
SequenceBuffer& operator=(const SequenceBuffer&) = delete;
|
||||
~SequenceBuffer() = default;
|
||||
size_t size() const { return S; }
|
||||
size_t chunks_size() const { return N; }
|
||||
// Sets the sequence buffer values to zero.
|
||||
void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); }
|
||||
// Returns a view on the whole buffer.
|
||||
rtc::ArrayView<const T, S> GetBufferView() const {
|
||||
return {buffer_.data(), S};
|
||||
}
|
||||
// Returns a view on the M most recent values of the buffer.
|
||||
rtc::ArrayView<const T, M> GetMostRecentValuesView() const {
|
||||
static_assert(M <= S,
|
||||
"The number of most recent values cannot be larger than the "
|
||||
"sequence buffer size.");
|
||||
return {buffer_.data() + S - M, M};
|
||||
}
|
||||
// Shifts the buffer left by N items and adds the new N items at the end.
|
||||
void Push(rtc::ArrayView<const T, N> new_values) {
|
||||
// Make space for the new values.
|
||||
if (S > N)
|
||||
std::memmove(buffer_.data(), buffer_.data() + N, (S - N) * sizeof(T));
|
||||
// Copy the new values at the end of the buffer.
|
||||
std::memcpy(buffer_.data() + S - N, new_values.data(), N * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<T> buffer_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
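Editor's note: a minimal usage sketch for the sequence buffer above (not part of the upstream tree), using a buffer of size 4 fed with chunks of 2 samples.

#include <array>

#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"

void SequenceBufferUsageSketch() {
  webrtc::rnn_vad::SequenceBuffer<float, 4, 2> buf;  // Initially all zeros.
  std::array<float, 2> chunk1 = {1.f, 2.f};
  std::array<float, 2> chunk2 = {3.f, 4.f};
  buf.Push(chunk1);  // Buffer now reads {0, 0, 1, 2}.
  buf.Push(chunk2);  // Buffer now reads {1, 2, 3, 4}.
  auto all = buf.GetBufferView();               // All 4 values.
  auto recent = buf.GetMostRecentValuesView();  // The last 2 values: {3, 4}.
  (void)all;
  (void)recent;
}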
@ -0,0 +1,213 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
constexpr float kSilenceThreshold = 0.04f;
|
||||
|
||||
// Computes the new cepstral difference stats and pushes them into the passed
|
||||
// symmetric matrix buffer.
|
||||
void UpdateCepstralDifferenceStats(
|
||||
rtc::ArrayView<const float, kNumBands> new_cepstral_coeffs,
|
||||
const RingBuffer<float, kNumBands, kCepstralCoeffsHistorySize>& ring_buf,
|
||||
SymmetricMatrixBuffer<float, kCepstralCoeffsHistorySize>* sym_matrix_buf) {
|
||||
RTC_DCHECK(sym_matrix_buf);
|
||||
// Compute the new cepstral distance stats.
|
||||
std::array<float, kCepstralCoeffsHistorySize - 1> distances;
|
||||
for (size_t i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
|
||||
const size_t delay = i + 1;
|
||||
auto old_cepstral_coeffs = ring_buf.GetArrayView(delay);
|
||||
distances[i] = 0.f;
|
||||
for (size_t k = 0; k < kNumBands; ++k) {
|
||||
const float c = new_cepstral_coeffs[k] - old_cepstral_coeffs[k];
|
||||
distances[i] += c * c;
|
||||
}
|
||||
}
|
||||
// Push the new spectral distance stats into the symmetric matrix buffer.
|
||||
sym_matrix_buf->Push(distances);
|
||||
}
|
||||
|
||||
// Computes the first half of the Vorbis window, scaled by |scaling|.
|
||||
std::array<float, kFrameSize20ms24kHz / 2> ComputeScaledHalfVorbisWindow(
|
||||
float scaling = 1.f) {
|
||||
constexpr size_t kHalfSize = kFrameSize20ms24kHz / 2;
|
||||
std::array<float, kHalfSize> half_window{};
|
||||
for (size_t i = 0; i < kHalfSize; ++i) {
|
||||
half_window[i] =
|
||||
scaling *
|
||||
std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfSize) *
|
||||
std::sin(0.5 * kPi * (i + 0.5) / kHalfSize));
|
||||
}
|
||||
return half_window;
|
||||
}
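// Equivalently, with K = kFrameSize20ms24kHz and scaling factor s, each sample of
// the half window computed above is
//   w[i] = s * sin((pi / 2) * sin(pi * (i + 0.5) / K)^2),  i = 0, ..., K / 2 - 1,
// i.e. the first half of the Vorbis power-complementary window.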
|
||||
|
||||
// Computes the forward FFT on a 20 ms frame to which a given window function is
|
||||
// applied. The Fourier coefficient corresponding to the Nyquist frequency is
|
||||
// set to zero (it is never used and this simplifies the code).
|
||||
void ComputeWindowedForwardFft(
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> frame,
|
||||
const std::array<float, kFrameSize20ms24kHz / 2>& half_window,
|
||||
Pffft::FloatBuffer* fft_input_buffer,
|
||||
Pffft::FloatBuffer* fft_output_buffer,
|
||||
Pffft* fft) {
|
||||
RTC_DCHECK_EQ(frame.size(), 2 * half_window.size());
|
||||
// Apply windowing.
|
||||
auto in = fft_input_buffer->GetView();
|
||||
for (size_t i = 0, j = kFrameSize20ms24kHz - 1; i < half_window.size();
|
||||
++i, --j) {
|
||||
in[i] = frame[i] * half_window[i];
|
||||
in[j] = frame[j] * half_window[i];
|
||||
}
|
||||
fft->ForwardTransform(*fft_input_buffer, fft_output_buffer, /*ordered=*/true);
|
||||
// Set the Nyquist frequency coefficient to zero.
|
||||
auto out = fft_output_buffer->GetView();
|
||||
out[1] = 0.f;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SpectralFeaturesExtractor::SpectralFeaturesExtractor()
|
||||
: half_window_(ComputeScaledHalfVorbisWindow(
|
||||
1.f / static_cast<float>(kFrameSize20ms24kHz))),
|
||||
fft_(kFrameSize20ms24kHz, Pffft::FftType::kReal),
|
||||
fft_buffer_(fft_.CreateBuffer()),
|
||||
reference_frame_fft_(fft_.CreateBuffer()),
|
||||
lagged_frame_fft_(fft_.CreateBuffer()),
|
||||
dct_table_(ComputeDctTable()) {}
|
||||
|
||||
SpectralFeaturesExtractor::~SpectralFeaturesExtractor() = default;
|
||||
|
||||
void SpectralFeaturesExtractor::Reset() {
|
||||
cepstral_coeffs_ring_buf_.Reset();
|
||||
cepstral_diffs_buf_.Reset();
|
||||
}
|
||||
|
||||
bool SpectralFeaturesExtractor::CheckSilenceComputeFeatures(
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame,
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> lagged_frame,
|
||||
rtc::ArrayView<float, kNumBands - kNumLowerBands> higher_bands_cepstrum,
|
||||
rtc::ArrayView<float, kNumLowerBands> average,
|
||||
rtc::ArrayView<float, kNumLowerBands> first_derivative,
|
||||
rtc::ArrayView<float, kNumLowerBands> second_derivative,
|
||||
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr,
|
||||
float* variability) {
|
||||
// Compute the Opus band energies for the reference frame.
|
||||
ComputeWindowedForwardFft(reference_frame, half_window_, fft_buffer_.get(),
|
||||
reference_frame_fft_.get(), &fft_);
|
||||
spectral_correlator_.ComputeAutoCorrelation(
|
||||
reference_frame_fft_->GetConstView(), reference_frame_bands_energy_);
|
||||
// Check if the reference frame has silence.
|
||||
const float tot_energy =
|
||||
std::accumulate(reference_frame_bands_energy_.begin(),
|
||||
reference_frame_bands_energy_.end(), 0.f);
|
||||
if (tot_energy < kSilenceThreshold) {
|
||||
return true;
|
||||
}
|
||||
// Compute the Opus band energies for the lagged frame.
|
||||
ComputeWindowedForwardFft(lagged_frame, half_window_, fft_buffer_.get(),
|
||||
lagged_frame_fft_.get(), &fft_);
|
||||
spectral_correlator_.ComputeAutoCorrelation(lagged_frame_fft_->GetConstView(),
|
||||
lagged_frame_bands_energy_);
|
||||
// Log of the band energies for the reference frame.
|
||||
std::array<float, kNumBands> log_bands_energy;
|
||||
ComputeSmoothedLogMagnitudeSpectrum(reference_frame_bands_energy_,
|
||||
log_bands_energy);
|
||||
// Reference frame cepstrum.
|
||||
std::array<float, kNumBands> cepstrum;
|
||||
ComputeDct(log_bands_energy, dct_table_, cepstrum);
|
||||
// Ad-hoc correction terms for the first two cepstral coefficients.
|
||||
cepstrum[0] -= 12.f;
|
||||
cepstrum[1] -= 4.f;
|
||||
// Update the ring buffer and the cepstral difference stats.
|
||||
cepstral_coeffs_ring_buf_.Push(cepstrum);
|
||||
UpdateCepstralDifferenceStats(cepstrum, cepstral_coeffs_ring_buf_,
|
||||
&cepstral_diffs_buf_);
|
||||
// Write the higher bands cepstral coefficients.
|
||||
RTC_DCHECK_EQ(cepstrum.size() - kNumLowerBands, higher_bands_cepstrum.size());
|
||||
std::copy(cepstrum.begin() + kNumLowerBands, cepstrum.end(),
|
||||
higher_bands_cepstrum.begin());
|
||||
// Compute and write remaining features.
|
||||
ComputeAvgAndDerivatives(average, first_derivative, second_derivative);
|
||||
ComputeNormalizedCepstralCorrelation(bands_cross_corr);
|
||||
RTC_DCHECK(variability);
|
||||
*variability = ComputeVariability();
|
||||
return false;
|
||||
}
|
||||
|
||||
void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
|
||||
rtc::ArrayView<float, kNumLowerBands> average,
|
||||
rtc::ArrayView<float, kNumLowerBands> first_derivative,
|
||||
rtc::ArrayView<float, kNumLowerBands> second_derivative) const {
|
||||
auto curr = cepstral_coeffs_ring_buf_.GetArrayView(0);
|
||||
auto prev1 = cepstral_coeffs_ring_buf_.GetArrayView(1);
|
||||
auto prev2 = cepstral_coeffs_ring_buf_.GetArrayView(2);
|
||||
RTC_DCHECK_EQ(average.size(), first_derivative.size());
|
||||
RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size());
|
||||
RTC_DCHECK_LE(average.size(), curr.size());
|
||||
for (size_t i = 0; i < average.size(); ++i) {
|
||||
// Average, kernel: [1, 1, 1].
|
||||
average[i] = curr[i] + prev1[i] + prev2[i];
|
||||
// First derivative, kernel: [1, 0, -1].
|
||||
first_derivative[i] = curr[i] - prev2[i];
|
||||
// Second derivative, Laplacian kernel: [1, -2, 1].
|
||||
second_derivative[i] = curr[i] - 2 * prev1[i] + prev2[i];
|
||||
}
|
||||
}
|
||||
|
||||
void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
|
||||
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr) {
|
||||
spectral_correlator_.ComputeCrossCorrelation(
|
||||
reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(),
|
||||
bands_cross_corr_);
|
||||
// Normalize.
|
||||
for (size_t i = 0; i < bands_cross_corr_.size(); ++i) {
|
||||
bands_cross_corr_[i] =
|
||||
bands_cross_corr_[i] /
|
||||
std::sqrt(0.001f + reference_frame_bands_energy_[i] *
|
||||
lagged_frame_bands_energy_[i]);
|
||||
}
|
||||
// Cepstrum.
|
||||
ComputeDct(bands_cross_corr_, dct_table_, bands_cross_corr);
|
||||
// Ad-hoc correction terms for the first two cepstral coefficients.
|
||||
bands_cross_corr[0] -= 1.3f;
|
||||
bands_cross_corr[1] -= 0.9f;
|
||||
}
|
||||
|
||||
float SpectralFeaturesExtractor::ComputeVariability() const {
|
||||
// Compute cepstral variability score.
|
||||
float variability = 0.f;
|
||||
for (size_t delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
|
||||
float min_dist = std::numeric_limits<float>::max();
|
||||
for (size_t delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
|
||||
if (delay1 == delay2) // The distance would be 0.
|
||||
continue;
|
||||
min_dist =
|
||||
std::min(min_dist, cepstral_diffs_buf_.GetValue(delay1, delay2));
|
||||
}
|
||||
variability += min_dist;
|
||||
}
|
||||
// Normalize (based on training set stats).
|
||||
// TODO(bugs.webrtc.org/10480): Isolate normalization from feature extraction.
|
||||
return variability / kCepstralCoeffsHistorySize - 2.1f;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
@ -0,0 +1,79 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h"
|
||||
#include "modules/audio_processing/utility/pffft_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Class to compute spectral features.
|
||||
class SpectralFeaturesExtractor {
|
||||
public:
|
||||
SpectralFeaturesExtractor();
|
||||
SpectralFeaturesExtractor(const SpectralFeaturesExtractor&) = delete;
|
||||
SpectralFeaturesExtractor& operator=(const SpectralFeaturesExtractor&) =
|
||||
delete;
|
||||
~SpectralFeaturesExtractor();
|
||||
// Resets the internal state of the feature extractor.
|
||||
void Reset();
|
||||
// Analyzes a pair of reference and lagged frames from the pitch buffer,
|
||||
// detects silence and computes features. If silence is detected, the output
|
||||
// is neither computed nor written.
|
||||
bool CheckSilenceComputeFeatures(
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame,
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> lagged_frame,
|
||||
rtc::ArrayView<float, kNumBands - kNumLowerBands> higher_bands_cepstrum,
|
||||
rtc::ArrayView<float, kNumLowerBands> average,
|
||||
rtc::ArrayView<float, kNumLowerBands> first_derivative,
|
||||
rtc::ArrayView<float, kNumLowerBands> second_derivative,
|
||||
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr,
|
||||
float* variability);
|
||||
|
||||
private:
|
||||
void ComputeAvgAndDerivatives(
|
||||
rtc::ArrayView<float, kNumLowerBands> average,
|
||||
rtc::ArrayView<float, kNumLowerBands> first_derivative,
|
||||
rtc::ArrayView<float, kNumLowerBands> second_derivative) const;
|
||||
void ComputeNormalizedCepstralCorrelation(
|
||||
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr);
|
||||
float ComputeVariability() const;
|
||||
|
||||
const std::array<float, kFrameSize20ms24kHz / 2> half_window_;
|
||||
Pffft fft_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> fft_buffer_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> reference_frame_fft_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> lagged_frame_fft_;
|
||||
SpectralCorrelator spectral_correlator_;
|
||||
std::array<float, kOpusBands24kHz> reference_frame_bands_energy_;
|
||||
std::array<float, kOpusBands24kHz> lagged_frame_bands_energy_;
|
||||
std::array<float, kOpusBands24kHz> bands_cross_corr_;
|
||||
const std::array<float, kNumBands * kNumBands> dct_table_;
|
||||
RingBuffer<float, kNumBands, kCepstralCoeffsHistorySize>
|
||||
cepstral_coeffs_ring_buf_;
|
||||
SymmetricMatrixBuffer<float, kCepstralCoeffsHistorySize> cepstral_diffs_buf_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
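A minimal sketch of the intended call pattern for the extractor declared above (illustrative only; in the real pipeline the two input frames come from the pitch buffer at 24 kHz, here they are zero-initialized for brevity):

#include <array>

#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"

void SpectralFeaturesSketch() {
  using namespace webrtc::rnn_vad;
  SpectralFeaturesExtractor extractor;
  std::array<float, kFrameSize20ms24kHz> reference_frame{};
  std::array<float, kFrameSize20ms24kHz> lagged_frame{};
  std::array<float, kNumBands - kNumLowerBands> higher_bands_cepstrum;
  std::array<float, kNumLowerBands> average, first_derivative, second_derivative,
      bands_cross_corr;
  float variability;
  const bool is_silence = extractor.CheckSilenceComputeFeatures(
      reference_frame, lagged_frame, higher_bands_cepstrum, average,
      first_derivative, second_derivative, bands_cross_corr, &variability);
  static_cast<void>(is_silence);  // When true, the outputs are left untouched.
}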
|
@ -0,0 +1,187 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Weights for each FFT coefficient for each Opus band (Nyquist frequency
|
||||
// excluded). The size of each band is specified in
|
||||
// |kOpusScaleNumBins24kHz20ms|.
|
||||
constexpr std::array<float, kFrameSize20ms24kHz / 2> kOpusBandWeights24kHz20ms =
|
||||
{{
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 0
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 1
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 2
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 3
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 4
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 5
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 6
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 7
|
||||
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
|
||||
0.625f, 0.75f, 0.875f, // Band 8
|
||||
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
|
||||
0.625f, 0.75f, 0.875f, // Band 9
|
||||
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
|
||||
0.625f, 0.75f, 0.875f, // Band 10
|
||||
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
|
||||
0.625f, 0.75f, 0.875f, // Band 11
|
||||
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
|
||||
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
|
||||
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
|
||||
0.9375f, // Band 12
|
||||
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
|
||||
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
|
||||
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
|
||||
0.9375f, // Band 13
|
||||
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
|
||||
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
|
||||
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
|
||||
0.9375f, // Band 14
|
||||
0.f, 0.0416667f, 0.0833333f, 0.125f, 0.166667f,
|
||||
0.208333f, 0.25f, 0.291667f, 0.333333f, 0.375f,
|
||||
0.416667f, 0.458333f, 0.5f, 0.541667f, 0.583333f,
|
||||
0.625f, 0.666667f, 0.708333f, 0.75f, 0.791667f,
|
||||
0.833333f, 0.875f, 0.916667f, 0.958333f, // Band 15
|
||||
0.f, 0.0416667f, 0.0833333f, 0.125f, 0.166667f,
|
||||
0.208333f, 0.25f, 0.291667f, 0.333333f, 0.375f,
|
||||
0.416667f, 0.458333f, 0.5f, 0.541667f, 0.583333f,
|
||||
0.625f, 0.666667f, 0.708333f, 0.75f, 0.791667f,
|
||||
0.833333f, 0.875f, 0.916667f, 0.958333f, // Band 16
|
||||
0.f, 0.03125f, 0.0625f, 0.09375f, 0.125f,
|
||||
0.15625f, 0.1875f, 0.21875f, 0.25f, 0.28125f,
|
||||
0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f,
|
||||
0.46875f, 0.5f, 0.53125f, 0.5625f, 0.59375f,
|
||||
0.625f, 0.65625f, 0.6875f, 0.71875f, 0.75f,
|
||||
0.78125f, 0.8125f, 0.84375f, 0.875f, 0.90625f,
|
||||
0.9375f, 0.96875f, // Band 17
|
||||
0.f, 0.0208333f, 0.0416667f, 0.0625f, 0.0833333f,
|
||||
0.104167f, 0.125f, 0.145833f, 0.166667f, 0.1875f,
|
||||
0.208333f, 0.229167f, 0.25f, 0.270833f, 0.291667f,
|
||||
0.3125f, 0.333333f, 0.354167f, 0.375f, 0.395833f,
|
||||
0.416667f, 0.4375f, 0.458333f, 0.479167f, 0.5f,
|
||||
0.520833f, 0.541667f, 0.5625f, 0.583333f, 0.604167f,
|
||||
0.625f, 0.645833f, 0.666667f, 0.6875f, 0.708333f,
|
||||
0.729167f, 0.75f, 0.770833f, 0.791667f, 0.8125f,
|
||||
0.833333f, 0.854167f, 0.875f, 0.895833f, 0.916667f,
|
||||
0.9375f, 0.958333f, 0.979167f // Band 18
|
||||
}};
|
||||
|
||||
} // namespace
|
||||
|
||||
SpectralCorrelator::SpectralCorrelator()
|
||||
: weights_(kOpusBandWeights24kHz20ms.begin(),
|
||||
kOpusBandWeights24kHz20ms.end()) {}
|
||||
|
||||
SpectralCorrelator::~SpectralCorrelator() = default;
|
||||
|
||||
void SpectralCorrelator::ComputeAutoCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const {
|
||||
ComputeCrossCorrelation(x, x, auto_corr);
|
||||
}
|
||||
|
||||
void SpectralCorrelator::ComputeCrossCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const {
|
||||
RTC_DCHECK_EQ(x.size(), kFrameSize20ms24kHz);
|
||||
RTC_DCHECK_EQ(x.size(), y.size());
|
||||
RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
|
||||
RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
|
||||
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
|
||||
size_t k = 0; // Next Fourier coefficient index.
|
||||
cross_corr[0] = 0.f;
|
||||
for (size_t i = 0; i < kOpusBands24kHz - 1; ++i) {
|
||||
cross_corr[i + 1] = 0.f;
|
||||
for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) { // Band size.
|
||||
const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1];
|
||||
const float tmp = weights_[k] * v;
|
||||
cross_corr[i] += v - tmp;
|
||||
cross_corr[i + 1] += tmp;
|
||||
k++;
|
||||
}
|
||||
}
|
||||
cross_corr[0] *= 2.f; // The first band only gets half contribution.
|
||||
RTC_DCHECK_EQ(k, kFrameSize20ms24kHz / 2); // Nyquist coefficient never used.
|
||||
}
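// In other words, each Fourier coefficient k that falls into band i contributes
// (1 - weights_[k]) of its product to band i and weights_[k] to band i + 1, which
// implements the triangular Opus-scale band filters described in the header.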
|
||||
|
||||
void ComputeSmoothedLogMagnitudeSpectrum(
|
||||
rtc::ArrayView<const float> bands_energy,
|
||||
rtc::ArrayView<float, kNumBands> log_bands_energy) {
|
||||
RTC_DCHECK_LE(bands_energy.size(), kNumBands);
|
||||
constexpr float kOneByHundred = 1e-2f;
|
||||
constexpr float kLogOneByHundred = -2.f;
|
||||
// Init.
|
||||
float log_max = kLogOneByHundred;
|
||||
float follow = kLogOneByHundred;
|
||||
const auto smooth = [&log_max, &follow](float x) {
|
||||
x = std::max(log_max - 7.f, std::max(follow - 1.5f, x));
|
||||
log_max = std::max(log_max, x);
|
||||
follow = std::max(follow - 1.5f, x);
|
||||
return x;
|
||||
};
|
||||
// Smoothing over the bands for which the band energy is defined.
|
||||
for (size_t i = 0; i < bands_energy.size(); ++i) {
|
||||
log_bands_energy[i] = smooth(std::log10(kOneByHundred + bands_energy[i]));
|
||||
}
|
||||
// Smoothing over the remaining bands (zero energy).
|
||||
for (size_t i = bands_energy.size(); i < kNumBands; ++i) {
|
||||
log_bands_energy[i] = smooth(kLogOneByHundred);
|
||||
}
|
||||
}
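// The lambda above clips each log-energy from below at (running maximum - 7) and
// at (previous output - 1.5), i.e. it bounds the dynamic range of the log spectrum
// and limits how quickly it can drop from one band to the next.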
|
||||
|
||||
std::array<float, kNumBands * kNumBands> ComputeDctTable() {
|
||||
std::array<float, kNumBands * kNumBands> dct_table;
|
||||
const double k = std::sqrt(0.5);
|
||||
for (size_t i = 0; i < kNumBands; ++i) {
|
||||
for (size_t j = 0; j < kNumBands; ++j)
|
||||
dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
|
||||
dct_table[i * kNumBands] *= k;
|
||||
}
|
||||
return dct_table;
|
||||
}
|
||||
|
||||
void ComputeDct(rtc::ArrayView<const float> in,
|
||||
rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
|
||||
rtc::ArrayView<float> out) {
|
||||
// DCT scaling factor - i.e., sqrt(2 / kNumBands).
|
||||
constexpr float kDctScalingFactor = 0.301511345f;
|
||||
constexpr float kDctScalingFactorError =
|
||||
kDctScalingFactor * kDctScalingFactor -
|
||||
2.f / static_cast<float>(kNumBands);
|
||||
static_assert(
|
||||
(kDctScalingFactorError >= 0.f && kDctScalingFactorError < 1e-1f) ||
|
||||
(kDctScalingFactorError < 0.f && kDctScalingFactorError > -1e-1f),
|
||||
"kNumBands changed and kDctScalingFactor has not been updated.");
|
||||
RTC_DCHECK_NE(in.data(), out.data()) << "In-place DCT is not supported.";
|
||||
RTC_DCHECK_LE(in.size(), kNumBands);
|
||||
RTC_DCHECK_LE(1, out.size());
|
||||
RTC_DCHECK_LE(out.size(), in.size());
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
out[i] = 0.f;
|
||||
for (size_t j = 0; j < in.size(); ++j) {
|
||||
out[i] += in[j] * dct_table[j * kNumBands + i];
|
||||
}
|
||||
// TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table.
|
||||
out[i] *= kDctScalingFactor;
|
||||
}
|
||||
}
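// The transform above is the orthonormal DCT-II:
//   out[i] = sqrt(2 / kNumBands) * c_i * sum_j in[j] * cos(pi * i * (j + 0.5) / kNumBands)
// with c_0 = 1 / sqrt(2) and c_i = 1 otherwise; the 1 / sqrt(2) factor is
// pre-applied to the first entry of each row of the table built by
// ComputeDctTable(), which only the i = 0 output reads. The hard-coded
// kDctScalingFactor equals sqrt(2 / 22), i.e. it matches kNumBands == 22.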
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
@ -0,0 +1,100 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// At a sample rate of 24 kHz, the last 3 Opus bands are beyond the Nyquist
|
||||
// frequency. However, band #19 gets the contributions from band #18 because
|
||||
// of the symmetric triangular filter with peak response at 12 kHz.
|
||||
constexpr size_t kOpusBands24kHz = 20;
|
||||
static_assert(kOpusBands24kHz < kNumBands,
|
||||
"The number of bands at 24 kHz must be less than those defined "
|
||||
"in the Opus scale at 48 kHz.");
|
||||
|
||||
// Number of FFT frequency bins covered by each band in the Opus scale at a
|
||||
// sample rate of 24 kHz for 20 ms frames.
|
||||
// Declared here for unit testing.
|
||||
constexpr std::array<int, kOpusBands24kHz - 1> GetOpusScaleNumBins24kHz20ms() {
|
||||
return {4, 4, 4, 4, 4, 4, 4, 4, 8, 8, 8, 8, 16, 16, 16, 24, 24, 32, 48};
|
||||
}
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Move to a separate file.
|
||||
// Class to compute band-wise spectral features in the Opus perceptual scale
|
||||
// for 20 ms frames sampled at 24 kHz. The analysis methods apply triangular
|
||||
// filters with peak response at each band boundary.
|
||||
class SpectralCorrelator {
|
||||
public:
|
||||
// Ctor.
|
||||
SpectralCorrelator();
|
||||
SpectralCorrelator(const SpectralCorrelator&) = delete;
|
||||
SpectralCorrelator& operator=(const SpectralCorrelator&) = delete;
|
||||
~SpectralCorrelator();
|
||||
|
||||
// Computes the band-wise spectral auto-correlations.
|
||||
// |x| must:
|
||||
// - have size equal to |kFrameSize20ms24kHz|;
|
||||
// - be encoded as vectors of interleaved real-complex FFT coefficients
|
||||
//   where x[1] = 0 (the Nyquist frequency coefficient is omitted).
|
||||
void ComputeAutoCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const;
|
||||
|
||||
// Computes the band-wise spectral cross-correlations.
|
||||
// |x| and |y| must:
|
||||
// - have size equal to |kFrameSize20ms24kHz|;
|
||||
// - be encoded as vectors of interleaved real-complex FFT coefficients where
|
||||
// x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
|
||||
void ComputeCrossCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const;
|
||||
|
||||
private:
|
||||
const std::vector<float> weights_; // Weights for each Fourier coefficient.
|
||||
};
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
|
||||
// spectral_features.cc. Given a vector of Opus-bands energy coefficients,
|
||||
// computes the log magnitude spectrum applying smoothing both over time and
|
||||
// over frequency. Declared here for unit testing.
|
||||
void ComputeSmoothedLogMagnitudeSpectrum(
|
||||
rtc::ArrayView<const float> bands_energy,
|
||||
rtc::ArrayView<float, kNumBands> log_bands_energy);
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
|
||||
// spectral_features.cc. Creates a DCT table for arrays having size equal to
|
||||
// |kNumBands|. Declared here for unit testing.
|
||||
std::array<float, kNumBands * kNumBands> ComputeDctTable();
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
|
||||
// spectral_features.cc. Computes DCT for |in| given a pre-computed DCT table.
|
||||
// In-place computation is not allowed and |out| can be smaller than |in| in
|
||||
// order to only compute the first DCT coefficients. Declared here for unit
|
||||
// testing.
|
||||
void ComputeDct(rtc::ArrayView<const float> in,
|
||||
rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
|
||||
rtc::ArrayView<float> out);
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
|
@ -0,0 +1,94 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <utility>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Data structure to buffer the results of pair-wise comparisons between items
|
||||
// stored in a ring buffer. Every time that the oldest item is replaced in the
|
||||
// ring buffer, the new one is compared to the remaining items in the ring
|
||||
// buffer. The results of such comparisons need to be buffered and automatically
|
||||
// removed when one of the two corresponding items that have been compared is
|
||||
// removed from the ring buffer. It is assumed that the comparison is symmetric
|
||||
// and that comparing an item with itself is not needed.
|
||||
template <typename T, size_t S>
|
||||
class SymmetricMatrixBuffer {
|
||||
static_assert(S > 2, "");
|
||||
|
||||
public:
|
||||
SymmetricMatrixBuffer() = default;
|
||||
SymmetricMatrixBuffer(const SymmetricMatrixBuffer&) = delete;
|
||||
SymmetricMatrixBuffer& operator=(const SymmetricMatrixBuffer&) = delete;
|
||||
~SymmetricMatrixBuffer() = default;
|
||||
// Sets the buffer values to zero.
|
||||
void Reset() {
|
||||
static_assert(std::is_arithmetic<T>::value,
|
||||
"Integral or floating point required.");
|
||||
buf_.fill(0);
|
||||
}
|
||||
// Pushes the results from the comparison between the most recent item and
|
||||
// those that are still in the ring buffer. The first element in |values| must
|
||||
// correspond to the comparison between the most recent item and the second
|
||||
// most recent one in the ring buffer, whereas the last element in |values|
|
||||
// must correspond to the comparison between the most recent item and the
|
||||
// oldest one in the ring buffer.
|
||||
void Push(rtc::ArrayView<T, S - 1> values) {
|
||||
// Move the lower-right sub-matrix of size (S-2) x (S-2) one row up and one
|
||||
// column left.
|
||||
std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T));
|
||||
// Copy new values in the last column in the right order.
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
const size_t index = (S - 1 - i) * (S - 1) - 1;
|
||||
RTC_DCHECK_LE(static_cast<size_t>(0), index);
|
||||
RTC_DCHECK_LT(index, buf_.size());
|
||||
buf_[index] = values[i];
|
||||
}
|
||||
}
|
||||
// Reads the value that corresponds to comparison of two items in the ring
|
||||
// buffer having delay |delay1| and |delay2|. The two arguments must not be
|
||||
// equal and both must be in {0, ..., S - 1}.
|
||||
T GetValue(size_t delay1, size_t delay2) const {
|
||||
int row = S - 1 - static_cast<int>(delay1);
|
||||
int col = S - 1 - static_cast<int>(delay2);
|
||||
RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed.";
|
||||
if (row > col)
|
||||
std::swap(row, col); // Swap to access the upper-right triangular part.
|
||||
RTC_DCHECK_LE(0, row);
|
||||
RTC_DCHECK_LT(row, S - 1) << "Not enforcing row < col and row != col.";
|
||||
RTC_DCHECK_LE(1, col) << "Not enforcing row < col and row != col.";
|
||||
RTC_DCHECK_LT(col, S);
|
||||
const int index = row * (S - 1) + (col - 1);
|
||||
RTC_DCHECK_LE(0, index);
|
||||
RTC_DCHECK_LT(index, buf_.size());
|
||||
return buf_[index];
|
||||
}
|
||||
|
||||
private:
|
||||
// Encode an upper-right triangular matrix (excluding its diagonal) using a
|
||||
// square matrix. This allows moving the data in Push() with a single
|
||||
// operation.
|
||||
std::array<T, (S - 1) * (S - 1)> buf_{};
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
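A minimal usage sketch of the buffer declared above (not part of the upstream sources; sizes and values are invented): with S = 4 ring-buffer slots, each Push() receives the comparisons of the newest item against the three older ones, and GetValue() later retrieves any stored pair.

#include <array>

#include "modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h"

void SymmetricMatrixBufferSketch() {
  webrtc::rnn_vad::SymmetricMatrixBuffer<float, 4> buf;
  buf.Reset();
  // Distances between the item just pushed into the companion ring buffer and
  // the previous one, the one before that and the oldest one, respectively.
  std::array<float, 3> distances = {0.1f, 0.2f, 0.3f};
  buf.Push(distances);
  // Distance between the most recent item (delay 0) and the oldest one (delay 3).
  const float d = buf.GetValue(0, 3);  // 0.3f.
  static_cast<void>(d);
}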
|
129
webrtc/modules/audio_processing/agc2/rnn_vad/test_utils.cc
Normal file
@ -0,0 +1,129 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
#include "test/gtest.h"
|
||||
#include "test/testsupport/file_utils.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
using ReaderPairType =
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>;
|
||||
|
||||
} // namespace
|
||||
|
||||
using webrtc::test::ResourcePath;
|
||||
|
||||
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed) {
|
||||
ASSERT_EQ(expected.size(), computed.size());
|
||||
for (size_t i = 0; i < expected.size(); ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
EXPECT_FLOAT_EQ(expected[i], computed[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed,
|
||||
float tolerance) {
|
||||
ASSERT_EQ(expected.size(), computed.size());
|
||||
for (size_t i = 0; i < expected.size(); ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
EXPECT_NEAR(expected[i], computed[i], tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
|
||||
CreatePcmSamplesReader(const size_t frame_length) {
|
||||
auto ptr = std::make_unique<BinaryFileReader<int16_t, float>>(
|
||||
test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"),
|
||||
frame_length);
|
||||
// The last incomplete frame is ignored.
|
||||
return {std::move(ptr), ptr->data_length() / frame_length};
|
||||
}
|
||||
|
||||
ReaderPairType CreatePitchBuffer24kHzReader() {
|
||||
constexpr size_t cols = 864;
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
|
||||
return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)};
|
||||
}
|
||||
|
||||
ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
|
||||
constexpr size_t num_lp_residual_coeffs = 864;
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
|
||||
num_lp_residual_coeffs);
|
||||
return {std::move(ptr),
|
||||
rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
|
||||
}
|
||||
|
||||
ReaderPairType CreateVadProbsReader() {
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
|
||||
return {std::move(ptr), ptr->data_length()};
|
||||
}
|
||||
|
||||
PitchTestData::PitchTestData() {
|
||||
BinaryFileReader<float> test_data_reader(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
|
||||
static_cast<size_t>(1396));
|
||||
test_data_reader.ReadChunk(test_data_);
|
||||
}
|
||||
|
||||
PitchTestData::~PitchTestData() = default;
|
||||
|
||||
rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
|
||||
const {
|
||||
return {test_data_.data(), kBufSize24kHz};
|
||||
}
|
||||
|
||||
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
|
||||
PitchTestData::GetPitchBufSquareEnergiesView() const {
|
||||
return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
|
||||
}
|
||||
|
||||
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
|
||||
PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
|
||||
return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
|
||||
kNumPitchBufAutoCorrCoeffs};
|
||||
}
|
||||
|
||||
bool IsOptimizationAvailable(Optimization optimization) {
|
||||
switch (optimization) {
|
||||
case Optimization::kSse2:
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
return GetCPUInfo(kSSE2) != 0;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
case Optimization::kNeon:
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
case Optimization::kNone:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
161
webrtc/modules/audio_processing/agc2/rnn_vad/test_utils.h
Normal file
@ -0,0 +1,161 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
|
||||
constexpr float kFloatMin = std::numeric_limits<float>::min();
|
||||
|
||||
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
|
||||
// that the values in the pair do not match.
|
||||
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed);
|
||||
|
||||
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
|
||||
// that their absolute error is above a given threshold.
|
||||
void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed,
|
||||
float tolerance);
|
||||
|
||||
// Reader for binary files consisting of an arbitrarily long sequence of elements
|
||||
// having type T. It is possible to read and cast to another type D at once.
|
||||
template <typename T, typename D = T>
|
||||
class BinaryFileReader {
|
||||
public:
|
||||
explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0)
|
||||
: is_(file_path, std::ios::binary | std::ios::ate),
|
||||
data_length_(is_.tellg() / sizeof(T)),
|
||||
chunk_size_(chunk_size) {
|
||||
RTC_CHECK(is_);
|
||||
SeekBeginning();
|
||||
buf_.resize(chunk_size_);
|
||||
}
|
||||
BinaryFileReader(const BinaryFileReader&) = delete;
|
||||
BinaryFileReader& operator=(const BinaryFileReader&) = delete;
|
||||
~BinaryFileReader() = default;
|
||||
size_t data_length() const { return data_length_; }
|
||||
bool ReadValue(D* dst) {
|
||||
if (std::is_same<T, D>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst), sizeof(T));
|
||||
} else {
|
||||
T v;
|
||||
is_.read(reinterpret_cast<char*>(&v), sizeof(T));
|
||||
*dst = static_cast<D>(v);
|
||||
}
|
||||
return is_.gcount() == sizeof(T);
|
||||
}
|
||||
// If |chunk_size| was specified in the ctor, it will check that the size of
|
||||
// |dst| equals |chunk_size|.
|
||||
bool ReadChunk(rtc::ArrayView<D> dst) {
|
||||
RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size()));
|
||||
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
|
||||
if (std::is_same<T, D>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
|
||||
} else {
|
||||
is_.read(reinterpret_cast<char*>(buf_.data()), bytes_to_read);
|
||||
std::transform(buf_.begin(), buf_.end(), dst.begin(),
|
||||
[](const T& v) -> D { return static_cast<D>(v); });
|
||||
}
|
||||
return is_.gcount() == bytes_to_read;
|
||||
}
|
||||
void SeekForward(size_t items) { is_.seekg(items * sizeof(T), is_.cur); }
|
||||
void SeekBeginning() { is_.seekg(0, is_.beg); }
|
||||
|
||||
private:
|
||||
std::ifstream is_;
|
||||
const size_t data_length_;
|
||||
const size_t chunk_size_;
|
||||
std::vector<T> buf_;
|
||||
};
|
||||
|
||||
// Writer for binary files.
|
||||
template <typename T>
|
||||
class BinaryFileWriter {
|
||||
public:
|
||||
explicit BinaryFileWriter(const std::string& file_path)
|
||||
: os_(file_path, std::ios::binary) {}
|
||||
BinaryFileWriter(const BinaryFileWriter&) = delete;
|
||||
BinaryFileWriter& operator=(const BinaryFileWriter&) = delete;
|
||||
~BinaryFileWriter() = default;
|
||||
static_assert(std::is_arithmetic<T>::value, "");
|
||||
void WriteChunk(rtc::ArrayView<const T> value) {
|
||||
const std::streamsize bytes_to_write = value.size() * sizeof(T);
|
||||
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
|
||||
}
|
||||
|
||||
private:
|
||||
std::ofstream os_;
|
||||
};
|
||||
|
||||
// Factories for resource file readers.
|
||||
// The functions below return a pair where the first item is a reader unique
|
||||
// pointer and the second the number of chunks that can be read from the file.
|
||||
// Creates a reader for the PCM samples that casts from S16 to float and reads
|
||||
// chunks with length |frame_length|.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
|
||||
CreatePcmSamplesReader(const size_t frame_length);
|
||||
// Creates a reader for the pitch buffer content at 24 kHz.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
||||
CreatePitchBuffer24kHzReader();
|
||||
// Creates a reader for the LP residual coefficients and the pitch period
|
||||
// and gain values.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
||||
CreateLpResidualAndPitchPeriodGainReader();
|
||||
// Creates a reader for the VAD probabilities.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
||||
CreateVadProbsReader();
|
||||
|
||||
constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
|
||||
constexpr size_t kNumPitchBufSquareEnergies = 385;
|
||||
constexpr size_t kPitchTestDataSize =
|
||||
kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
|
||||
|
||||
// Class to retrieve the content of a test pitch buffer and the expected output for the
|
||||
// analysis steps.
|
||||
class PitchTestData {
|
||||
public:
|
||||
PitchTestData();
|
||||
~PitchTestData();
|
||||
rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() const;
|
||||
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
|
||||
GetPitchBufSquareEnergiesView() const;
|
||||
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
|
||||
GetPitchBufAutoCorrCoeffsView() const;
|
||||
|
||||
private:
|
||||
std::array<float, kPitchTestDataSize> test_data_;
|
||||
};
|
||||
|
||||
// Returns true if the given optimization is available.
|
||||
bool IsOptimizationAvailable(Optimization optimization);
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
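A minimal sketch of how the reader above is meant to be used (the file name is hypothetical): read 20 ms chunks of 16 bit PCM samples at 24 kHz, casting them to float on the fly.

#include <cstddef>
#include <cstdint>
#include <vector>

#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"

void BinaryFileReaderSketch() {
  constexpr size_t kChunkSize = 480;  // 20 ms at 24 kHz.
  webrtc::rnn_vad::test::BinaryFileReader<int16_t, float> reader(
      "/tmp/samples.pcm", kChunkSize);
  std::vector<float> chunk(kChunkSize);
  while (reader.ReadChunk(chunk)) {
    // Process |chunk|.
  }
}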
|
121
webrtc/modules/audio_processing/agc2/saturation_protector.cc
Normal file
@ -0,0 +1,121 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/saturation_protector.h"
|
||||
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
constexpr float kMinLevelDbfs = -90.f;
|
||||
|
||||
// Min/max margins are based on speech crest-factor.
|
||||
constexpr float kMinMarginDb = 12.f;
|
||||
constexpr float kMaxMarginDb = 25.f;
|
||||
|
||||
using saturation_protector_impl::RingBuffer;
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RingBuffer::operator==(const RingBuffer& b) const {
|
||||
RTC_DCHECK_LE(size_, buffer_.size());
|
||||
RTC_DCHECK_LE(b.size_, b.buffer_.size());
|
||||
if (size_ != b.size_) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0, i0 = FrontIndex(), i1 = b.FrontIndex(); i < size_;
|
||||
++i, ++i0, ++i1) {
|
||||
if (buffer_[i0 % buffer_.size()] != b.buffer_[i1 % b.buffer_.size()]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void RingBuffer::Reset() {
|
||||
next_ = 0;
|
||||
size_ = 0;
|
||||
}
|
||||
|
||||
void RingBuffer::PushBack(float v) {
|
||||
RTC_DCHECK_GE(next_, 0);
|
||||
RTC_DCHECK_GE(size_, 0);
|
||||
RTC_DCHECK_LT(next_, buffer_.size());
|
||||
RTC_DCHECK_LE(size_, buffer_.size());
|
||||
buffer_[next_++] = v;
|
||||
if (rtc::SafeEq(next_, buffer_.size())) {
|
||||
next_ = 0;
|
||||
}
|
||||
if (rtc::SafeLt(size_, buffer_.size())) {
|
||||
size_++;
|
||||
}
|
||||
}
|
||||
|
||||
absl::optional<float> RingBuffer::Front() const {
|
||||
if (size_ == 0) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
RTC_DCHECK_LT(FrontIndex(), buffer_.size());
|
||||
return buffer_[FrontIndex()];
|
||||
}
|
||||
|
||||
bool SaturationProtectorState::operator==(
|
||||
const SaturationProtectorState& b) const {
|
||||
return margin_db == b.margin_db && peak_delay_buffer == b.peak_delay_buffer &&
|
||||
max_peaks_dbfs == b.max_peaks_dbfs &&
|
||||
time_since_push_ms == b.time_since_push_ms;
|
||||
}
|
||||
|
||||
void ResetSaturationProtectorState(float initial_margin_db,
|
||||
SaturationProtectorState& state) {
|
||||
state.margin_db = initial_margin_db;
|
||||
state.peak_delay_buffer.Reset();
|
||||
state.max_peaks_dbfs = kMinLevelDbfs;
|
||||
state.time_since_push_ms = 0;
|
||||
}
|
||||
|
||||
void UpdateSaturationProtectorState(float speech_peak_dbfs,
|
||||
float speech_level_dbfs,
|
||||
SaturationProtectorState& state) {
|
||||
// Get the max peak over `kPeakEnveloperSuperFrameLengthMs` ms.
|
||||
state.max_peaks_dbfs = std::max(state.max_peaks_dbfs, speech_peak_dbfs);
|
||||
state.time_since_push_ms += kFrameDurationMs;
|
||||
if (rtc::SafeGt(state.time_since_push_ms, kPeakEnveloperSuperFrameLengthMs)) {
|
||||
// Push `max_peaks_dbfs` back into the ring buffer.
|
||||
state.peak_delay_buffer.PushBack(state.max_peaks_dbfs);
|
||||
// Reset.
|
||||
state.max_peaks_dbfs = kMinLevelDbfs;
|
||||
state.time_since_push_ms = 0;
|
||||
}
|
||||
|
||||
// Update margin by comparing the estimated speech level and the delayed max
|
||||
// speech peak power.
|
||||
// TODO(alessiob): Check with aleloi@ why we use a delay and how to tune it.
|
||||
const float delayed_peak_dbfs =
|
||||
state.peak_delay_buffer.Front().value_or(state.max_peaks_dbfs);
|
||||
const float difference_db = delayed_peak_dbfs - speech_level_dbfs;
|
||||
if (difference_db > state.margin_db) {
|
||||
// Attack.
|
||||
state.margin_db =
|
||||
state.margin_db * kSaturationProtectorAttackConstant +
|
||||
difference_db * (1.f - kSaturationProtectorAttackConstant);
|
||||
} else {
|
||||
// Decay.
|
||||
state.margin_db = state.margin_db * kSaturationProtectorDecayConstant +
|
||||
difference_db * (1.f - kSaturationProtectorDecayConstant);
|
||||
}
|
||||
|
||||
state.margin_db =
|
||||
rtc::SafeClamp<float>(state.margin_db, kMinMarginDb, kMaxMarginDb);
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
82
webrtc/modules/audio_processing/agc2/saturation_protector.h
Normal file
@ -0,0 +1,82 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace saturation_protector_impl {
|
||||
|
||||
// Ring buffer which only supports (i) push back and (ii) read oldest item.
|
||||
class RingBuffer {
|
||||
public:
|
||||
bool operator==(const RingBuffer& b) const;
|
||||
inline bool operator!=(const RingBuffer& b) const { return !(*this == b); }
|
||||
|
||||
// Maximum number of values that the buffer can contain.
|
||||
int Capacity() const { return buffer_.size(); }
|
||||
// Number of values in the buffer.
|
||||
int Size() const { return size_; }
|
||||
|
||||
void Reset();
|
||||
// Pushes back `v`. If the buffer is full, the oldest value is replaced.
|
||||
void PushBack(float v);
|
||||
// Returns the oldest item in the buffer. Returns an empty value if the
|
||||
// buffer is empty.
|
||||
absl::optional<float> Front() const;
|
||||
|
||||
private:
|
||||
inline int FrontIndex() const {
|
||||
return rtc::SafeEq(size_, buffer_.size()) ? next_ : 0;
|
||||
}
|
||||
// `buffer_` has `size_` elements (up to the size of `buffer_`) and `next_` is
|
||||
// the position where the next new value is written in `buffer_`.
|
||||
std::array<float, kPeakEnveloperBufferSize> buffer_;
|
||||
int next_ = 0;
|
||||
int size_ = 0;
|
||||
};
|
||||
|
||||
} // namespace saturation_protector_impl
|
||||
|
||||
// Saturation protector state. Exposed publicly for check-pointing and restore
|
||||
// ops.
|
||||
struct SaturationProtectorState {
|
||||
bool operator==(const SaturationProtectorState& s) const;
|
||||
inline bool operator!=(const SaturationProtectorState& s) const {
|
||||
return !(*this == s);
|
||||
}
|
||||
|
||||
float margin_db; // Recommended margin.
|
||||
saturation_protector_impl::RingBuffer peak_delay_buffer;
|
||||
float max_peaks_dbfs;
|
||||
int time_since_push_ms; // Time since the last ring buffer push operation.
|
||||
};
|
||||
|
||||
// Resets the saturation protector state.
|
||||
void ResetSaturationProtectorState(float initial_margin_db,
|
||||
SaturationProtectorState& state);
|
||||
|
||||
// Updates `state` by analyzing the estimated speech level `speech_level_dbfs`
|
||||
// and the peak power `speech_peak_dbfs` for an observed frame which is
|
||||
// reliably classified as "speech". `state` must not be modified without calling
|
||||
// this function.
|
||||
void UpdateSaturationProtectorState(float speech_peak_dbfs,
|
||||
float speech_level_dbfs,
|
||||
SaturationProtectorState& state);
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
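A minimal sketch of the intended call pattern (the dBFS values below are made up; in the adaptive digital controller they come from the speech level and peak estimators):

#include "modules/audio_processing/agc2/saturation_protector.h"

void SaturationProtectorSketch() {
  webrtc::SaturationProtectorState state;
  webrtc::ResetSaturationProtectorState(/*initial_margin_db=*/20.f, state);
  // Called for every 10 ms frame that is reliably classified as speech.
  webrtc::UpdateSaturationProtectorState(/*speech_peak_dbfs=*/-8.f,
                                         /*speech_level_dbfs=*/-26.f, state);
  // state.margin_db now holds the recommended headroom below full scale.
  static_cast<void>(state.margin_db);
}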
|
177
webrtc/modules/audio_processing/agc2/signal_classifier.cc
Normal file
@ -0,0 +1,177 @@
|
||||
/*
|
||||
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/signal_classifier.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/down_sampler.h"
|
||||
#include "modules/audio_processing/agc2/noise_spectrum_estimator.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
bool IsSse2Available() {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
return GetCPUInfo(kSSE2) != 0;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
void RemoveDcLevel(rtc::ArrayView<float> x) {
|
||||
RTC_DCHECK_LT(0, x.size());
|
||||
float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f);
|
||||
mean /= x.size();
|
||||
|
||||
for (float& v : x) {
|
||||
v -= mean;
|
||||
}
|
||||
}
|
||||
|
||||
void PowerSpectrum(const OouraFft* ooura_fft,
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float> spectrum) {
|
||||
RTC_DCHECK_EQ(65, spectrum.size());
|
||||
RTC_DCHECK_EQ(128, x.size());
|
||||
float X[128];
|
||||
std::copy(x.data(), x.data() + x.size(), X);
|
||||
ooura_fft->Fft(X);
|
||||
|
||||
float* X_p = X;
|
||||
RTC_DCHECK_EQ(X_p, &X[0]);
|
||||
spectrum[0] = (*X_p) * (*X_p);
|
||||
++X_p;
|
||||
RTC_DCHECK_EQ(X_p, &X[1]);
|
||||
spectrum[64] = (*X_p) * (*X_p);
|
||||
for (int k = 1; k < 64; ++k) {
|
||||
++X_p;
|
||||
RTC_DCHECK_EQ(X_p, &X[2 * k]);
|
||||
spectrum[k] = (*X_p) * (*X_p);
|
||||
++X_p;
|
||||
RTC_DCHECK_EQ(X_p, &X[2 * k + 1]);
|
||||
spectrum[k] += (*X_p) * (*X_p);
|
||||
}
|
||||
}
|
||||
|
||||
webrtc::SignalClassifier::SignalType ClassifySignal(
|
||||
rtc::ArrayView<const float> signal_spectrum,
|
||||
rtc::ArrayView<const float> noise_spectrum,
|
||||
ApmDataDumper* data_dumper) {
|
||||
int num_stationary_bands = 0;
|
||||
int num_highly_nonstationary_bands = 0;
|
||||
|
||||
// Detect stationary and highly nonstationary bands.
|
||||
for (size_t k = 1; k < 40; k++) {
|
||||
if (signal_spectrum[k] < 3 * noise_spectrum[k] &&
|
||||
signal_spectrum[k] * 3 > noise_spectrum[k]) {
|
||||
++num_stationary_bands;
|
||||
} else if (signal_spectrum[k] > 9 * noise_spectrum[k]) {
|
||||
++num_highly_nonstationary_bands;
|
||||
}
|
||||
}
|
||||
|
||||
data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands);
|
||||
data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1,
|
||||
&num_highly_nonstationary_bands);
|
||||
|
||||
// Use the detected number of bands to classify the overall signal
|
||||
// stationarity.
|
||||
if (num_stationary_bands > 15) {
|
||||
return SignalClassifier::SignalType::kStationary;
|
||||
} else {
|
||||
return SignalClassifier::SignalType::kNonStationary;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SignalClassifier::FrameExtender::FrameExtender(size_t frame_size,
|
||||
size_t extended_frame_size)
|
||||
: x_old_(extended_frame_size - frame_size, 0.f) {}
|
||||
|
||||
SignalClassifier::FrameExtender::~FrameExtender() = default;
|
||||
|
||||
void SignalClassifier::FrameExtender::ExtendFrame(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float> x_extended) {
|
||||
RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size());
|
||||
std::copy(x_old_.data(), x_old_.data() + x_old_.size(), x_extended.data());
|
||||
std::copy(x.data(), x.data() + x.size(), x_extended.data() + x_old_.size());
|
||||
std::copy(x_extended.data() + x_extended.size() - x_old_.size(),
|
||||
x_extended.data() + x_extended.size(), x_old_.data());
|
||||
}
|
||||
|
||||
SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper)
|
||||
: data_dumper_(data_dumper),
|
||||
down_sampler_(data_dumper_),
|
||||
noise_spectrum_estimator_(data_dumper_),
|
||||
ooura_fft_(IsSse2Available()) {
|
||||
Initialize(48000);
|
||||
}
|
||||
SignalClassifier::~SignalClassifier() {}
|
||||
|
||||
void SignalClassifier::Initialize(int sample_rate_hz) {
|
||||
down_sampler_.Initialize(sample_rate_hz);
|
||||
noise_spectrum_estimator_.Initialize();
|
||||
frame_extender_.reset(new FrameExtender(80, 128));
|
||||
sample_rate_hz_ = sample_rate_hz;
|
||||
initialization_frames_left_ = 2;
|
||||
consistent_classification_counter_ = 3;
|
||||
last_signal_type_ = SignalClassifier::SignalType::kNonStationary;
|
||||
}
|
||||
|
||||
SignalClassifier::SignalType SignalClassifier::Analyze(
|
||||
rtc::ArrayView<const float> signal) {
|
||||
RTC_DCHECK_EQ(signal.size(), sample_rate_hz_ / 100);
|
||||
|
||||
// Compute the signal power spectrum.
|
||||
float downsampled_frame[80];
|
||||
down_sampler_.DownSample(signal, downsampled_frame);
|
||||
float extended_frame[128];
|
||||
frame_extender_->ExtendFrame(downsampled_frame, extended_frame);
|
||||
RemoveDcLevel(extended_frame);
|
||||
float signal_spectrum[65];
|
||||
PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum);
|
||||
|
||||
// Classify the signal based on the estimate of the noise spectrum and the
|
||||
// signal spectrum estimate.
|
||||
const SignalType signal_type = ClassifySignal(
|
||||
signal_spectrum, noise_spectrum_estimator_.GetNoiseSpectrum(),
|
||||
data_dumper_);
|
||||
|
||||
// Update the noise spectrum based on the signal spectrum.
|
||||
noise_spectrum_estimator_.Update(signal_spectrum,
|
||||
initialization_frames_left_ > 0);
|
||||
|
||||
// Update the number of frames until a reliable signal spectrum is achieved.
|
||||
initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1);
|
||||
|
||||
if (last_signal_type_ == signal_type) {
|
||||
consistent_classification_counter_ =
|
||||
std::max(0, consistent_classification_counter_ - 1);
|
||||
} else {
|
||||
last_signal_type_ = signal_type;
|
||||
consistent_classification_counter_ = 3;
|
||||
}
|
||||
|
||||
if (consistent_classification_counter_ > 0) {
|
||||
return SignalClassifier::SignalType::kNonStationary;
|
||||
}
|
||||
return signal_type;
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
73
webrtc/modules/audio_processing/agc2/signal_classifier.h
Normal file
@ -0,0 +1,73 @@
/*
 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_

#include <memory>
#include <vector>

#include "api/array_view.h"
#include "common_audio/third_party/ooura/fft_size_128/ooura_fft.h"
#include "modules/audio_processing/agc2/down_sampler.h"
#include "modules/audio_processing/agc2/noise_spectrum_estimator.h"

namespace webrtc {

class ApmDataDumper;
class AudioBuffer;

class SignalClassifier {
 public:
  enum class SignalType { kNonStationary, kStationary };

  explicit SignalClassifier(ApmDataDumper* data_dumper);

  SignalClassifier() = delete;
  SignalClassifier(const SignalClassifier&) = delete;
  SignalClassifier& operator=(const SignalClassifier&) = delete;

  ~SignalClassifier();

  void Initialize(int sample_rate_hz);
  SignalType Analyze(rtc::ArrayView<const float> signal);

 private:
  class FrameExtender {
   public:
    FrameExtender(size_t frame_size, size_t extended_frame_size);

    FrameExtender() = delete;
    FrameExtender(const FrameExtender&) = delete;
    FrameExtender& operator=(const FrameExtender&) = delete;

    ~FrameExtender();

    void ExtendFrame(rtc::ArrayView<const float> x,
                     rtc::ArrayView<float> x_extended);

   private:
    std::vector<float> x_old_;
  };

  ApmDataDumper* const data_dumper_;
  DownSampler down_sampler_;
  std::unique_ptr<FrameExtender> frame_extender_;
  NoiseSpectrumEstimator noise_spectrum_estimator_;
  int sample_rate_hz_;
  int initialization_frames_left_;
  int consistent_classification_counter_;
  SignalType last_signal_type_;
  const OouraFft ooura_fft_;
};

}  // namespace webrtc

#endif  // MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_
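The private FrameExtender above is what turns each 80-sample downsampled frame into a 128-sample analysis block by prepending the tail of the previous frame (the 80/128 sizes come from Initialize() in the .cc file). A self-contained sketch of the same overlap logic, kept separate from the WebRTC types for illustration:

// Standalone sketch of the FrameExtender overlap logic (illustration only):
// each 80-sample frame is prepended with the last 48 samples of the previous
// block to form a 128-sample analysis frame, mirroring ExtendFrame() above.
#include <algorithm>
#include <vector>

class OverlapExtender {
 public:
  OverlapExtender(size_t frame_size, size_t extended_frame_size)
      : old_(extended_frame_size - frame_size, 0.f) {}

  // x has frame_size samples; x_extended has extended_frame_size samples.
  void Extend(const std::vector<float>& x, std::vector<float>& x_extended) {
    std::copy(old_.begin(), old_.end(), x_extended.begin());
    std::copy(x.begin(), x.end(), x_extended.begin() + old_.size());
    // Remember the tail of the extended block for the next call.
    std::copy(x_extended.end() - old_.size(), x_extended.end(), old_.begin());
  }

 private:
  std::vector<float> old_;  // 48 samples when used as FrameExtender(80, 128).
};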
114
webrtc/modules/audio_processing/agc2/vad_with_level.cc
Normal file
@ -0,0 +1,114 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/vad_with_level.h"

#include <algorithm>
#include <array>
#include <cmath>

#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/checks.h"

namespace webrtc {
namespace {

using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector;

// Default VAD that combines a resampler and the RNN VAD.
// Computes the speech probability on the first channel.
class Vad : public VoiceActivityDetector {
 public:
  Vad() = default;
  Vad(const Vad&) = delete;
  Vad& operator=(const Vad&) = delete;
  ~Vad() = default;

  float ComputeProbability(AudioFrameView<const float> frame) override {
    // The source number of channels is 1, because we always use the 1st
    // channel.
    resampler_.InitializeIfNeeded(
        /*sample_rate_hz=*/static_cast<int>(frame.samples_per_channel() * 100),
        rnn_vad::kSampleRate24kHz,
        /*num_channels=*/1);

    std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
    // Feed the 1st channel to the resampler.
    resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
                        work_frame.data(), rnn_vad::kFrameSize10ms24kHz);

    std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
    const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
        work_frame, feature_vector);
    return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
  }

 private:
  PushResampler<float> resampler_;
  rnn_vad::FeaturesExtractor features_extractor_;
  rnn_vad::RnnBasedVad rnn_vad_;
};

// Returns an updated version of `p_old` by using instant decay and the given
// `attack` on a new VAD probability value `p_new`.
float SmoothedVadProbability(float p_old, float p_new, float attack) {
  RTC_DCHECK_GT(attack, 0.f);
  RTC_DCHECK_LE(attack, 1.f);
  if (p_new < p_old || attack == 1.f) {
    // Instant decay (or no smoothing).
    return p_new;
  } else {
    // Attack phase.
    return attack * p_new + (1.f - attack) * p_old;
  }
}

}  // namespace

VadLevelAnalyzer::VadLevelAnalyzer()
    : VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack,
                       std::make_unique<Vad>()) {}

VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack)
    : VadLevelAnalyzer(vad_probability_attack, std::make_unique<Vad>()) {}

VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack,
                                   std::unique_ptr<VoiceActivityDetector> vad)
    : vad_(std::move(vad)), vad_probability_attack_(vad_probability_attack) {
  RTC_DCHECK(vad_);
}

VadLevelAnalyzer::~VadLevelAnalyzer() = default;

VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
    AudioFrameView<const float> frame) {
  // Compute levels.
  float peak = 0.f;
  float rms = 0.f;
  for (const auto& x : frame.channel(0)) {
    peak = std::max(std::fabs(x), peak);
    rms += x * x;
  }
  // Compute smoothed speech probability.
  vad_probability_ = SmoothedVadProbability(
      /*p_old=*/vad_probability_, /*p_new=*/vad_->ComputeProbability(frame),
      vad_probability_attack_);
  return {vad_probability_,
          FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())),
          FloatS16ToDbfs(peak)};
}

}  // namespace webrtc
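The smoothing in SmoothedVadProbability() rises with a one-pole attack filter and falls instantly when the new probability is lower. A standalone worked example (illustration only; the real kDefaultSmoothedVadProbabilityAttack constant is defined in agc2_common.h and is not shown in this diff, so the 0.3 attack below is an assumed value):

// Worked example of the probability smoothing above (illustration only).
#include <cstdio>

float SmoothedVadProbability(float p_old, float p_new, float attack) {
  if (p_new < p_old || attack == 1.f) {
    return p_new;  // Instant decay (or no smoothing).
  }
  return attack * p_new + (1.f - attack) * p_old;  // Attack phase.
}

int main() {
  const float attack = 0.3f;  // Assumed value for illustration.
  float p = 0.f;
  // Rising input: the smoothed value approaches 1.0 gradually.
  for (int i = 0; i < 3; ++i) {
    p = SmoothedVadProbability(p, /*p_new=*/1.f, attack);
    std::printf("attack step %d: %.3f\n", i, p);  // 0.300, 0.510, 0.657
  }
  // Falling input: the smoothed value drops immediately.
  p = SmoothedVadProbability(p, /*p_new=*/0.1f, attack);
  std::printf("decay step: %.3f\n", p);  // 0.100
  return 0;
}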
58
webrtc/modules/audio_processing/agc2/vad_with_level.h
Normal file
@ -0,0 +1,58 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_

#include <memory>

#include "modules/audio_processing/include/audio_frame_view.h"

namespace webrtc {

// Class to analyze voice activity and audio levels.
class VadLevelAnalyzer {
 public:
  struct Result {
    float speech_probability;  // Range: [0, 1].
    float rms_dbfs;            // Root mean square power (dBFS).
    float peak_dbfs;           // Peak power (dBFS).
  };

  // Voice Activity Detector (VAD) interface.
  class VoiceActivityDetector {
   public:
    virtual ~VoiceActivityDetector() = default;
    // Analyzes an audio frame and returns the speech probability.
    virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
  };

  // Ctor. Uses the default VAD.
  VadLevelAnalyzer();
  explicit VadLevelAnalyzer(float vad_probability_attack);
  // Ctor. Uses a custom `vad`.
  VadLevelAnalyzer(float vad_probability_attack,
                   std::unique_ptr<VoiceActivityDetector> vad);
  VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
  VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
  ~VadLevelAnalyzer();

  // Computes the speech probability and the level for `frame`.
  Result AnalyzeFrame(AudioFrameView<const float> frame);

 private:
  std::unique_ptr<VoiceActivityDetector> vad_;
  const float vad_probability_attack_;
  float vad_probability_ = 0.f;
};

}  // namespace webrtc

#endif  // MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
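The VoiceActivityDetector interface plus the two-argument constructor allow injecting a custom VAD, e.g. in tests that should not pull in the RNN VAD. A minimal sketch (not part of this commit; the constant probability, attack value and frame layout are illustrative assumptions):

// Sketch only: injects a fake VAD that always reports the same probability,
// then analyzes a single mono 10 ms frame.
#include <array>
#include <memory>

#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h"

namespace {

class ConstantVad : public webrtc::VadLevelAnalyzer::VoiceActivityDetector {
 public:
  float ComputeProbability(
      webrtc::AudioFrameView<const float> /*frame*/) override {
    return 0.8f;  // Illustrative constant probability.
  }
};

}  // namespace

webrtc::VadLevelAnalyzer::Result AnalyzeOneFrame() {
  // attack == 1.f means no smoothing, so the fake probability passes through.
  webrtc::VadLevelAnalyzer analyzer(/*vad_probability_attack=*/1.f,
                                    std::make_unique<ConstantVad>());
  std::array<float, 480> samples{};  // 10 ms at 48 kHz, one channel.
  float* channels[] = {samples.data()};
  webrtc::AudioFrameView<const float> frame(channels, /*num_channels=*/1,
                                            /*channel_size=*/480);
  return analyzer.AnalyzeFrame(frame);  // speech_probability == 0.8f.
}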
39
webrtc/modules/audio_processing/agc2/vector_float_frame.cc
Normal file
@ -0,0 +1,39 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/vector_float_frame.h"

namespace webrtc {

namespace {

std::vector<float*> ConstructChannelPointers(
    std::vector<std::vector<float>>* x) {
  std::vector<float*> channel_ptrs;
  for (auto& v : *x) {
    channel_ptrs.push_back(v.data());
  }
  return channel_ptrs;
}
}  // namespace

VectorFloatFrame::VectorFloatFrame(int num_channels,
                                   int samples_per_channel,
                                   float start_value)
    : channels_(num_channels,
                std::vector<float>(samples_per_channel, start_value)),
      channel_ptrs_(ConstructChannelPointers(&channels_)),
      float_frame_view_(channel_ptrs_.data(),
                        channels_.size(),
                        samples_per_channel) {}

VectorFloatFrame::~VectorFloatFrame() = default;

}  // namespace webrtc
42
webrtc/modules/audio_processing/agc2/vector_float_frame.h
Normal file
@ -0,0 +1,42 @@
/*
 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_
#define MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_

#include <vector>

#include "modules/audio_processing/include/audio_frame_view.h"

namespace webrtc {

// A construct consisting of a multi-channel audio frame, and a FloatFrame view
// of it.
class VectorFloatFrame {
 public:
  VectorFloatFrame(int num_channels,
                   int samples_per_channel,
                   float start_value);
  const AudioFrameView<float>& float_frame_view() { return float_frame_view_; }
  AudioFrameView<const float> float_frame_view() const {
    return float_frame_view_;
  }

  ~VectorFloatFrame();

 private:
  std::vector<std::vector<float>> channels_;
  std::vector<float*> channel_ptrs_;
  AudioFrameView<float> float_frame_view_;
};

}  // namespace webrtc

#endif  // MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_
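VectorFloatFrame is a small helper that owns per-channel sample buffers and exposes them as an AudioFrameView, which is what the AGC2 components in this commit consume. A short usage sketch (not part of this commit; the frame dimensions and values are illustrative):

// Usage sketch only: builds a silent stereo frame and feeds its view to
// VadLevelAnalyzer::AnalyzeFrame() from vad_with_level.h above.
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/agc2/vector_float_frame.h"

void AnalyzeSilentFrame() {
  // A stereo 10 ms frame at 48 kHz, initialized to silence.
  webrtc::VectorFloatFrame frame(/*num_channels=*/2,
                                 /*samples_per_channel=*/480,
                                 /*start_value=*/0.f);
  webrtc::VadLevelAnalyzer analyzer;  // Default ctor: RNN-based VAD.
  const auto result = analyzer.AnalyzeFrame(frame.float_frame_view());
  // result.speech_probability, result.rms_dbfs and result.peak_dbfs are now
  // populated; for an all-zero frame the levels sit at the dBFS floor.
  (void)result;
}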