Update to current webrtc library

This is from the upstream library commit id
3326535126e435f1ba647885ce43a8f0f3d317eb, corresponding to Chromium
88.0.4290.1.
This commit is contained in:
Arun Raghavan
2020-10-12 18:08:02 -04:00
parent b1b02581d3
commit bcec8b0b21
859 changed files with 76187 additions and 49580 deletions

View File

@ -0,0 +1,290 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
import("../../../webrtc.gni")
# Umbrella target bundling the two AGC2 gain controllers.
group("agc2") {
deps = [
":adaptive_digital",
":fixed_digital",
]
}
# Adapter exposing the AGC2 adaptive level estimator through the legacy
# AGC interface (see ../agc).
rtc_library("level_estimation_agc") {
sources = [
"adaptive_mode_level_estimator_agc.cc",
"adaptive_mode_level_estimator_agc.h",
]
configs += [ "..:apm_debug_dump" ]
deps = [
":adaptive_digital",
":common",
":gain_applier",
":noise_level_estimator",
":rnn_vad_with_level",
"..:api",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:safe_minmax",
"../agc:level_estimation",
"../vad",
]
}
# Adaptive digital controller: speech level estimation, digital gain
# application and saturation protection.
rtc_library("adaptive_digital") {
sources = [
"adaptive_agc.cc",
"adaptive_agc.h",
"adaptive_digital_gain_applier.cc",
"adaptive_digital_gain_applier.h",
"adaptive_mode_level_estimator.cc",
"adaptive_mode_level_estimator.h",
"saturation_protector.cc",
"saturation_protector.h",
]
configs += [ "..:apm_debug_dump" ]
deps = [
":common",
":gain_applier",
":noise_level_estimator",
":rnn_vad_with_level",
"..:api",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:logging",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:safe_compare",
"../../../rtc_base:safe_minmax",
"../../../system_wrappers:metrics",
]
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
}
# Biquad filtering primitive; visibility restricted to this directory.
rtc_library("biquad_filter") {
visibility = [ "./*" ]
sources = [
"biquad_filter.cc",
"biquad_filter.h",
]
deps = [
"../../../api:array_view",
"../../../rtc_base:rtc_base_approved",
]
}
# Header-only shared AGC2 constants.
rtc_source_set("common") {
sources = [ "agc2_common.h" ]
}
# Fixed digital controller: level estimator, interpolated gain curve and
# limiter.
rtc_library("fixed_digital") {
sources = [
"fixed_digital_level_estimator.cc",
"fixed_digital_level_estimator.h",
"interpolated_gain_curve.cc",
"interpolated_gain_curve.h",
"limiter.cc",
"limiter.h",
]
configs += [ "..:apm_debug_dump" ]
deps = [
":common",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:gtest_prod",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:safe_minmax",
"../../../system_wrappers:metrics",
]
}
# Applies (and ramps) a gain factor to audio frames.
rtc_library("gain_applier") {
sources = [
"gain_applier.cc",
"gain_applier.h",
]
deps = [
":common",
"..:audio_frame_view",
"../../../api:array_view",
"../../../rtc_base:safe_minmax",
]
}
# Noise level estimation based on signal classification and spectral
# analysis.
rtc_library("noise_level_estimator") {
sources = [
"down_sampler.cc",
"down_sampler.h",
"noise_level_estimator.cc",
"noise_level_estimator.h",
"noise_spectrum_estimator.cc",
"noise_spectrum_estimator.h",
"signal_classifier.cc",
"signal_classifier.h",
]
deps = [
":biquad_filter",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../common_audio/third_party/ooura:fft_size_128",
"../../../rtc_base:checks",
"../../../rtc_base:macromagic",
"../../../system_wrappers",
]
configs += [ "..:apm_debug_dump" ]
}
# RNN VAD wrapper that also reports per-frame speech RMS/peak levels.
rtc_library("rnn_vad_with_level") {
sources = [
"vad_with_level.cc",
"vad_with_level.h",
]
deps = [
":common",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../rtc_base:checks",
"rnn_vad",
"rnn_vad:rnn_vad_common",
]
}
# Unit tests for the adaptive digital controller.
rtc_library("adaptive_digital_unittests") {
testonly = true
configs += [ "..:apm_debug_dump" ]
sources = [
"adaptive_digital_gain_applier_unittest.cc",
"adaptive_mode_level_estimator_unittest.cc",
"gain_applier_unittest.cc",
"saturation_protector_unittest.cc",
]
deps = [
":adaptive_digital",
":common",
":gain_applier",
":test_utils",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:rtc_base_approved",
"../../../test:test_support",
]
}
# Unit tests for the biquad filter.
rtc_library("biquad_filter_unittests") {
testonly = true
sources = [ "biquad_filter_unittest.cc" ]
deps = [
":biquad_filter",
"../../../rtc_base:gunit_helpers",
]
}
# Unit tests for the fixed digital controller.
rtc_library("fixed_digital_unittests") {
testonly = true
configs += [ "..:apm_debug_dump" ]
sources = [
"agc2_testing_common_unittest.cc",
"compute_interpolated_gain_curve.cc",
"compute_interpolated_gain_curve.h",
"fixed_digital_level_estimator_unittest.cc",
"interpolated_gain_curve_unittest.cc",
"limiter_db_gain_curve.cc",
"limiter_db_gain_curve.h",
"limiter_db_gain_curve_unittest.cc",
"limiter_unittest.cc",
]
deps = [
":common",
":fixed_digital",
":test_utils",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../rtc_base:checks",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:rtc_base_approved",
"../../../system_wrappers:metrics",
]
}
# Unit tests for the noise level estimator and signal classifier.
rtc_library("noise_estimator_unittests") {
testonly = true
configs += [ "..:apm_debug_dump" ]
sources = [
"noise_level_estimator_unittest.cc",
"signal_classifier_unittest.cc",
]
deps = [
":noise_level_estimator",
":test_utils",
"..:apm_logging",
"..:audio_frame_view",
"../../../api:array_view",
"../../../rtc_base:checks",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:rtc_base_approved",
]
}
# Unit tests for the VAD-with-level wrapper.
rtc_library("rnn_vad_with_level_unittests") {
testonly = true
sources = [ "vad_with_level_unittest.cc" ]
deps = [
":common",
":rnn_vad_with_level",
"..:audio_frame_view",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:safe_compare",
"../../../test:test_support",
]
}
# Test-only helpers shared by the AGC2 unit tests.
rtc_library("test_utils") {
testonly = true
visibility = [
":*",
"..:audio_processing_unittests",
]
sources = [
"agc2_testing_common.cc",
"agc2_testing_common.h",
"vector_float_frame.cc",
"vector_float_frame.h",
]
deps = [
"..:audio_frame_view",
"../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved",
]
}

View File

@ -0,0 +1,90 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/adaptive_agc.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
namespace webrtc {
namespace {
// Dumps the per-frame AGC2 debug metrics carried by `info` into `dumper`.
void DumpDebugData(const AdaptiveDigitalGainApplier::FrameInfo& info,
ApmDataDumper& dumper) {
dumper.DumpRaw("agc2_vad_probability", info.vad_result.speech_probability);
dumper.DumpRaw("agc2_vad_rms_dbfs", info.vad_result.rms_dbfs);
dumper.DumpRaw("agc2_vad_peak_dbfs", info.vad_result.peak_dbfs);
dumper.DumpRaw("agc2_noise_estimate_dbfs", info.input_noise_level_dbfs);
dumper.DumpRaw("agc2_last_limiter_audio_level", info.limiter_envelope_dbfs);
}
// Defaults used by the parameterless constructor below.
constexpr int kGainApplierAdjacentSpeechFramesThreshold = 1;
constexpr float kMaxGainChangePerSecondDb = 3.f;
constexpr float kMaxOutputNoiseLevelDbfs = -50.f;
} // namespace
// Constructs the controller with the default tuning constants above.
AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper)
: speech_level_estimator_(apm_data_dumper),
gain_applier_(apm_data_dumper,
kGainApplierAdjacentSpeechFramesThreshold,
kMaxGainChangePerSecondDb,
kMaxOutputNoiseLevelDbfs),
apm_data_dumper_(apm_data_dumper),
noise_level_estimator_(apm_data_dumper) {
RTC_DCHECK(apm_data_dumper);
}
// Constructs the controller from the adaptive digital section of `config`.
AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2& config)
: speech_level_estimator_(
apm_data_dumper,
config.adaptive_digital.level_estimator,
config.adaptive_digital
.level_estimator_adjacent_speech_frames_threshold,
config.adaptive_digital.initial_saturation_margin_db,
config.adaptive_digital.extra_saturation_margin_db),
vad_(config.adaptive_digital.vad_probability_attack),
gain_applier_(
apm_data_dumper,
config.adaptive_digital.gain_applier_adjacent_speech_frames_threshold,
config.adaptive_digital.max_gain_change_db_per_second,
config.adaptive_digital.max_output_noise_level_dbfs),
apm_data_dumper_(apm_data_dumper),
noise_level_estimator_(apm_data_dumper) {
RTC_DCHECK(apm_data_dumper);
// The saturation protector is always active; surface attempts to disable it.
if (!config.adaptive_digital.use_saturation_protector) {
RTC_LOG(LS_WARNING) << "The saturation protector cannot be disabled.";
}
}
AdaptiveAgc::~AdaptiveAgc() = default;
// Runs the VAD, updates the speech and noise level estimates, then lets the
// gain applier adapt and apply the digital gain to `frame` in place.
void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
AdaptiveDigitalGainApplier::FrameInfo info;
info.vad_result = vad_.AnalyzeFrame(frame);
speech_level_estimator_.Update(info.vad_result);
info.input_level_dbfs = speech_level_estimator_.level_dbfs();
info.input_noise_level_dbfs = noise_level_estimator_.Analyze(frame);
// Convert the limiter envelope to dBFS; map a non-positive envelope to
// -90 dBFS (the floor used throughout AGC2).
info.limiter_envelope_dbfs =
limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.f;
info.estimate_is_confident = speech_level_estimator_.IsConfident();
DumpDebugData(info, *apm_data_dumper_);
gain_applier_.Process(info, frame);
}
// NOTE(review): only the speech level estimator is reset here; the gain
// applier and noise estimator keep their state — confirm this is intended.
void AdaptiveAgc::Reset() {
speech_level_estimator_.Reset();
}
} // namespace webrtc

View File

@ -0,0 +1,50 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
#include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/noise_level_estimator.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/include/audio_processing.h"
namespace webrtc {
class ApmDataDumper;
// Adaptive digital gain controller.
// TODO(crbug.com/webrtc/7494): Unify with `AdaptiveDigitalGainApplier`.
class AdaptiveAgc {
 public:
explicit AdaptiveAgc(ApmDataDumper* apm_data_dumper);
// TODO(crbug.com/webrtc/7494): Remove ctor above.
AdaptiveAgc(ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2& config);
~AdaptiveAgc();
// Analyzes `frame` and applies a digital adaptive gain to it. Takes into
// account the envelope measured by the limiter.
// TODO(crbug.com/webrtc/7494): Make the class depend on the limiter.
void Process(AudioFrameView<float> frame, float limiter_envelope);
// Resets the speech level estimator state.
void Reset();
 private:
// Estimates the speech (plus noise) level from VAD-annotated frames.
AdaptiveModeLevelEstimator speech_level_estimator_;
// Voice activity detector producing per-frame speech probability and levels.
VadLevelAnalyzer vad_;
// Adapts and applies the digital gain.
AdaptiveDigitalGainApplier gain_applier_;
// Debug data dumper (not owned; must outlive this object).
ApmDataDumper* const apm_data_dumper_;
NoiseLevelEstimator noise_level_estimator_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_

View File

@ -0,0 +1,179 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
#include <algorithm>
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
namespace {
// This function maps input level to desired applied gain. We want to
// boost the signal so that peaks are at -kHeadroomDbfs. We can't
// apply more than kMaxGainDb gain.
float ComputeGainDb(float input_level_dbfs) {
// If the level is very low, boost it as much as we can.
if (input_level_dbfs < -(kHeadroomDbfs + kMaxGainDb)) {
return kMaxGainDb;
}
// We expect to end up here most of the time: the level is below
// -headroom, but we can boost it to -headroom.
if (input_level_dbfs < -kHeadroomDbfs) {
return -kHeadroomDbfs - input_level_dbfs;
}
// Otherwise, the level is too high and we can't boost. The
// LevelEstimator is responsible for not reporting bogus gain
// values.
RTC_DCHECK_LE(input_level_dbfs, 0.f);
return 0.f;
}
// Returns `target_gain` if the output noise level is below
// `max_output_noise_level_dbfs`; otherwise returns a capped gain so that the
// output noise level equals `max_output_noise_level_dbfs`.
float LimitGainByNoise(float target_gain,
float input_noise_level_dbfs,
float max_output_noise_level_dbfs,
ApmDataDumper& apm_data_dumper) {
// Headroom left before the amplified noise would exceed the allowed level.
const float noise_headroom_db =
max_output_noise_level_dbfs - input_noise_level_dbfs;
apm_data_dumper.DumpRaw("agc2_noise_headroom_db", noise_headroom_db);
// Never return a negative gain cap: clamp the headroom at zero.
return std::min(target_gain, std::max(noise_headroom_db, 0.f));
}
// Caps `target_gain` when the speech level estimate is not confident and the
// limiter is already fed with a level above `kLimiterThresholdForAgcGainDbfs`,
// so that the gained signal does not push the limiter further.
float LimitGainByLowConfidence(float target_gain,
float last_gain,
float limiter_audio_level_dbfs,
bool estimate_is_confident) {
// No capping needed when confident or when the limiter is not engaged.
if (estimate_is_confident ||
limiter_audio_level_dbfs <= kLimiterThresholdForAgcGainDbfs) {
return target_gain;
}
const float limiter_level_before_gain = limiter_audio_level_dbfs - last_gain;
// Compute a new gain so that limiter_level_before_gain + new_gain <=
// kLimiterThreshold.
const float new_target_gain = std::max(
kLimiterThresholdForAgcGainDbfs - limiter_level_before_gain, 0.f);
return std::min(new_target_gain, target_gain);
}
// Computes how the gain should change during this frame.
// Return the gain difference in db to 'last_gain_db'.
float ComputeGainChangeThisFrameDb(float target_gain_db,
                                   float last_gain_db,
                                   bool gain_increase_allowed,
                                   float max_gain_change_db) {
  // Desired change towards the target gain.
  float gain_difference_db = target_gain_db - last_gain_db;
  // Suppress positive changes while gain increases are not allowed.
  if (!gain_increase_allowed && gain_difference_db > 0.f) {
    gain_difference_db = 0.f;
  }
  // Rate-limit the change to the configured per-frame maximum.
  return rtc::SafeClamp(gain_difference_db, -max_gain_change_db,
                        max_gain_change_db);
}
} // namespace
// See the class header for the meaning of the parameters. Note that the
// per-second adaptation speed is converted to a per-10 ms frame budget.
AdaptiveDigitalGainApplier::AdaptiveDigitalGainApplier(
ApmDataDumper* apm_data_dumper,
int adjacent_speech_frames_threshold,
float max_gain_change_db_per_second,
float max_output_noise_level_dbfs)
: apm_data_dumper_(apm_data_dumper),
gain_applier_(
/*hard_clip_samples=*/false,
/*initial_gain_factor=*/DbToRatio(kInitialAdaptiveDigitalGainDb)),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
max_gain_change_db_per_10ms_(max_gain_change_db_per_second *
kFrameDurationMs / 1000.f),
max_output_noise_level_dbfs_(max_output_noise_level_dbfs),
calls_since_last_gain_log_(0),
frames_to_gain_increase_allowed_(adjacent_speech_frames_threshold_),
last_gain_db_(kInitialAdaptiveDigitalGainDb) {
RTC_DCHECK_GT(max_gain_change_db_per_second, 0.f);
RTC_DCHECK_GE(frames_to_gain_increase_allowed_, 1);
RTC_DCHECK_GE(max_output_noise_level_dbfs_, -90.f);
RTC_DCHECK_LE(max_output_noise_level_dbfs_, 0.f);
}
// Computes the target gain from `info` (capped by noise level and estimate
// confidence), rate-limits the change and applies the result to `frame`.
void AdaptiveDigitalGainApplier::Process(const FrameInfo& info,
AudioFrameView<float> frame) {
RTC_DCHECK_GE(info.input_level_dbfs, -150.f);
RTC_DCHECK_GE(frame.num_channels(), 1);
// 80/160/320/480 samples correspond to 10 ms at 8/16/32/48 kHz.
RTC_DCHECK(
frame.samples_per_channel() == 80 || frame.samples_per_channel() == 160 ||
frame.samples_per_channel() == 320 || frame.samples_per_channel() == 480)
<< "`frame` does not look like a 10 ms frame for an APM supported sample "
"rate";
// Map the level to a gain, then apply the noise and low-confidence caps.
const float target_gain_db = LimitGainByLowConfidence(
LimitGainByNoise(ComputeGainDb(std::min(info.input_level_dbfs, 0.f)),
info.input_noise_level_dbfs,
max_output_noise_level_dbfs_, *apm_data_dumper_),
last_gain_db_, info.limiter_envelope_dbfs, info.estimate_is_confident);
// Forbid increasing the gain until enough adjacent speech frames are
// observed.
if (info.vad_result.speech_probability < kVadConfidenceThreshold) {
frames_to_gain_increase_allowed_ = adjacent_speech_frames_threshold_;
} else if (frames_to_gain_increase_allowed_ > 0) {
frames_to_gain_increase_allowed_--;
}
const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb(
target_gain_db, last_gain_db_,
/*gain_increase_allowed=*/frames_to_gain_increase_allowed_ == 0,
max_gain_change_db_per_10ms_);
apm_data_dumper_->DumpRaw("agc2_want_to_change_by_db",
target_gain_db - last_gain_db_);
apm_data_dumper_->DumpRaw("agc2_will_change_by_db",
gain_change_this_frame_db);
// Optimization: avoid calling math functions if gain does not
// change.
if (gain_change_this_frame_db != 0.f) {
gain_applier_.SetGainFactor(
DbToRatio(last_gain_db_ + gain_change_this_frame_db));
}
gain_applier_.ApplyGain(frame);
// Remember that the gain has changed for the next iteration.
last_gain_db_ = last_gain_db_ + gain_change_this_frame_db;
apm_data_dumper_->DumpRaw("agc2_applied_gain_db", last_gain_db_);
// Log every 10 seconds.
calls_since_last_gain_log_++;
if (calls_since_last_gain_log_ == 1000) {
calls_since_last_gain_log_ = 0;
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.DigitalGainApplied",
last_gain_db_, 0, kMaxGainDb, kMaxGainDb + 1);
RTC_HISTOGRAM_COUNTS_LINEAR(
"WebRTC.Audio.Agc2.EstimatedSpeechPlusNoiseLevel",
-info.input_level_dbfs, 0, 100, 101);
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedNoiseLevel",
-info.input_noise_level_dbfs, 0, 100, 101);
RTC_LOG(LS_INFO) << "AGC2 adaptive digital"
<< " | speech_plus_noise_dbfs: " << info.input_level_dbfs
<< " | noise_dbfs: " << info.input_noise_level_dbfs
<< " | gain_db: " << last_gain_db_;
}
}
} // namespace webrtc

View File

@ -0,0 +1,69 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_
#include "modules/audio_processing/agc2/gain_applier.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
class ApmDataDumper;
// Part of the adaptive digital controller that applies a digital adaptive gain.
// The gain is updated towards a target. The logic decides when gain updates are
// allowed, it controls the adaptation speed and caps the target based on the
// estimated noise level and the speech level estimate confidence.
class AdaptiveDigitalGainApplier {
 public:
// Information about a frame to process.
struct FrameInfo {
float input_level_dbfs; // Estimated speech plus noise level.
float input_noise_level_dbfs; // Estimated noise level.
// Speech probability and levels produced by the VAD for this frame.
VadLevelAnalyzer::Result vad_result;
float limiter_envelope_dbfs; // Envelope level from the limiter.
// True when the speech level estimate is considered reliable.
bool estimate_is_confident;
};
// Ctor.
// `adjacent_speech_frames_threshold` indicates how many speech frames are
// required before a gain increase is allowed. `max_gain_change_db_per_second`
// limits the adaptation speed (uniformly operated across frames).
// `max_output_noise_level_dbfs` limits the output noise level.
AdaptiveDigitalGainApplier(ApmDataDumper* apm_data_dumper,
int adjacent_speech_frames_threshold,
float max_gain_change_db_per_second,
float max_output_noise_level_dbfs);
AdaptiveDigitalGainApplier(const AdaptiveDigitalGainApplier&) = delete;
AdaptiveDigitalGainApplier& operator=(const AdaptiveDigitalGainApplier&) =
delete;
// Analyzes `info`, updates the digital gain and applies it to a 10 ms
// `frame`. Supports any sample rate supported by APM.
void Process(const FrameInfo& info, AudioFrameView<float> frame);
 private:
ApmDataDumper* const apm_data_dumper_;  // Not owned.
GainApplier gain_applier_;
const int adjacent_speech_frames_threshold_;
// Per-frame adaptation budget derived from the per-second configuration.
const float max_gain_change_db_per_10ms_;
const float max_output_noise_level_dbfs_;
// Counts `Process()` calls to rate-limit logging (every 1000 frames).
int calls_since_last_gain_log_;
// Remaining speech frames before a gain increase becomes allowed.
int frames_to_gain_increase_allowed_;
// Gain applied during the previous frame, in dB.
float last_gain_db_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_

View File

@ -0,0 +1,198 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
using LevelEstimatorType =
AudioProcessing::Config::GainController2::LevelEstimator;
// Combines a level estimation with the saturation protector margins.
// The result is clamped to [-90, 30] dBFS.
float ComputeLevelEstimateDbfs(float level_estimate_dbfs,
float saturation_margin_db,
float extra_saturation_margin_db) {
return rtc::SafeClamp<float>(
level_estimate_dbfs + saturation_margin_db + extra_saturation_margin_db,
-90.f, 30.f);
}
// Returns the level of given type from `vad_level`: the frame RMS for `kRms`,
// the frame peak for `kPeak`.
float GetLevel(const VadLevelAnalyzer::Result& vad_level,
               LevelEstimatorType type) {
  switch (type) {
    // The `break` statements that followed these returns were unreachable
    // dead code and have been removed.
    case LevelEstimatorType::kRms:
      return vad_level.rms_dbfs;
    case LevelEstimatorType::kPeak:
      return vad_level.peak_dbfs;
  }
}
} // namespace
// Field-wise equality over the check-pointed estimator state.
bool AdaptiveModeLevelEstimator::LevelEstimatorState::operator==(
const AdaptiveModeLevelEstimator::LevelEstimatorState& b) const {
return time_to_full_buffer_ms == b.time_to_full_buffer_ms &&
level_dbfs.numerator == b.level_dbfs.numerator &&
level_dbfs.denominator == b.level_dbfs.denominator &&
saturation_protector == b.saturation_protector;
}
// Returns numerator/denominator; the denominator must be non-zero (DCHECKed).
float AdaptiveModeLevelEstimator::LevelEstimatorState::Ratio::GetRatio() const {
RTC_DCHECK_NE(denominator, 0.f);
return numerator / denominator;
}
// Delegates to the fully parameterized constructor with the default tuning
// (RMS level estimation and the default thresholds/margins).
AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
ApmDataDumper* apm_data_dumper)
: AdaptiveModeLevelEstimator(
apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator::kRms,
kDefaultLevelEstimatorAdjacentSpeechFramesThreshold,
kDefaultInitialSaturationMarginDb,
kDefaultExtraSaturationMarginDb) {}
// `adjacent_speech_frames_threshold` must be >= 1 (DCHECKed below).
AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
ApmDataDumper* apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator level_estimator,
int adjacent_speech_frames_threshold,
float initial_saturation_margin_db,
float extra_saturation_margin_db)
: apm_data_dumper_(apm_data_dumper),
level_estimator_type_(level_estimator),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
initial_saturation_margin_db_(initial_saturation_margin_db),
extra_saturation_margin_db_(extra_saturation_margin_db),
level_dbfs_(ComputeLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs,
initial_saturation_margin_db_,
extra_saturation_margin_db_)) {
RTC_DCHECK(apm_data_dumper_);
RTC_DCHECK_GE(adjacent_speech_frames_threshold_, 1);
Reset();
}
// Updates the preliminary (and possibly the reliable) estimator state with
// `vad_level`; refreshes `level_dbfs_` once enough adjacent speech frames
// have been observed.
void AdaptiveModeLevelEstimator::Update(
const VadLevelAnalyzer::Result& vad_level) {
// Sanity-check the VAD output before using it.
RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
RTC_DCHECK_GT(vad_level.peak_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.peak_dbfs, 50.f);
RTC_DCHECK_GE(vad_level.speech_probability, 0.f);
RTC_DCHECK_LE(vad_level.speech_probability, 1.f);
DumpDebugData();
if (vad_level.speech_probability < kVadConfidenceThreshold) {
// Not a speech frame.
if (adjacent_speech_frames_threshold_ > 1) {
// When two or more adjacent speech frames are required in order to update
// the state, we need to decide whether to discard or confirm the updates
// based on the speech sequence length.
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// First non-speech frame after a long enough sequence of speech frames.
// Update the reliable state.
reliable_state_ = preliminary_state_;
} else if (num_adjacent_speech_frames_ > 0) {
// First non-speech frame after a too short sequence of speech frames.
// Reset to the last reliable state.
preliminary_state_ = reliable_state_;
}
}
num_adjacent_speech_frames_ = 0;
return;
}
// Speech frame observed.
num_adjacent_speech_frames_++;
// Update preliminary level estimate.
RTC_DCHECK_GE(preliminary_state_.time_to_full_buffer_ms, 0);
const bool buffer_is_full = preliminary_state_.time_to_full_buffer_ms == 0;
if (!buffer_is_full) {
preliminary_state_.time_to_full_buffer_ms -= kFrameDurationMs;
}
// Weighted average of levels with speech probability as weight.
RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
// Once the averaging buffer is full, older contributions leak away so the
// estimate can track level changes.
const float leak_factor = buffer_is_full ? kFullBufferLeakFactor : 1.f;
preliminary_state_.level_dbfs.numerator =
preliminary_state_.level_dbfs.numerator * leak_factor +
GetLevel(vad_level, level_estimator_type_) * vad_level.speech_probability;
preliminary_state_.level_dbfs.denominator =
preliminary_state_.level_dbfs.denominator * leak_factor +
vad_level.speech_probability;
const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();
UpdateSaturationProtectorState(vad_level.peak_dbfs, level_dbfs,
preliminary_state_.saturation_protector);
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// `preliminary_state_` is now reliable. Update the last level estimation.
level_dbfs_ = ComputeLevelEstimateDbfs(
level_dbfs, preliminary_state_.saturation_protector.margin_db,
extra_saturation_margin_db_);
}
}
// Returns true once enough data has been observed for `level_dbfs()` to be
// considered reliable.
bool AdaptiveModeLevelEstimator::IsConfident() const {
if (adjacent_speech_frames_threshold_ == 1) {
// Ignore `reliable_state_` when a single frame is enough to update the
// level estimate (because it is not used).
return preliminary_state_.time_to_full_buffer_ms == 0;
}
// Once confident, it remains confident.
RTC_DCHECK(reliable_state_.time_to_full_buffer_ms != 0 ||
preliminary_state_.time_to_full_buffer_ms == 0);
// During the first long enough speech sequence, `reliable_state_` must be
// ignored since `preliminary_state_` is used.
return reliable_state_.time_to_full_buffer_ms == 0 ||
(num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_ &&
preliminary_state_.time_to_full_buffer_ms == 0);
}
// Resets both estimator states, the cached level estimate and the adjacent
// speech frame counter.
void AdaptiveModeLevelEstimator::Reset() {
ResetLevelEstimatorState(preliminary_state_);
ResetLevelEstimatorState(reliable_state_);
level_dbfs_ = ComputeLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs,
initial_saturation_margin_db_,
extra_saturation_margin_db_);
num_adjacent_speech_frames_ = 0;
}
// Restores `state` to its initial value (empty averaging buffer, reset
// saturation protector).
void AdaptiveModeLevelEstimator::ResetLevelEstimatorState(
LevelEstimatorState& state) const {
state.time_to_full_buffer_ms = kFullBufferSizeMs;
state.level_dbfs.numerator = 0.f;
state.level_dbfs.denominator = 0.f;
ResetSaturationProtectorState(initial_saturation_margin_db_,
state.saturation_protector);
}
// Dumps the internal estimator state for offline analysis/tuning.
void AdaptiveModeLevelEstimator::DumpDebugData() const {
apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs", level_dbfs_);
apm_data_dumper_->DumpRaw("agc2_adaptive_num_adjacent_speech_frames_",
num_adjacent_speech_frames_);
apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_num",
preliminary_state_.level_dbfs.numerator);
apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_den",
preliminary_state_.level_dbfs.denominator);
apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_saturation_margin_db",
preliminary_state_.saturation_protector.margin_db);
}
} // namespace webrtc

View File

@ -0,0 +1,86 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
#include <stddef.h>
#include <type_traits>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_processing.h"
namespace webrtc {
class ApmDataDumper;
// Level estimator for the digital adaptive gain controller.
class AdaptiveModeLevelEstimator {
 public:
explicit AdaptiveModeLevelEstimator(ApmDataDumper* apm_data_dumper);
AdaptiveModeLevelEstimator(const AdaptiveModeLevelEstimator&) = delete;
AdaptiveModeLevelEstimator& operator=(const AdaptiveModeLevelEstimator&) =
delete;
AdaptiveModeLevelEstimator(
ApmDataDumper* apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator level_estimator,
int adjacent_speech_frames_threshold,
float initial_saturation_margin_db,
float extra_saturation_margin_db);
// Updates the level estimation.
void Update(const VadLevelAnalyzer::Result& vad_data);
// Returns the estimated speech plus noise level.
float level_dbfs() const { return level_dbfs_; }
// Returns true if the estimator is confident on its current estimate.
bool IsConfident() const;
void Reset();
 private:
// Part of the level estimator state used for check-pointing and restore ops.
struct LevelEstimatorState {
bool operator==(const LevelEstimatorState& s) const;
inline bool operator!=(const LevelEstimatorState& s) const {
return !(*this == s);
}
// Numerator/denominator of the speech-probability-weighted level average.
struct Ratio {
float numerator;
float denominator;
float GetRatio() const;
};
// TODO(crbug.com/webrtc/7494): Remove time_to_full_buffer_ms if redundant.
int time_to_full_buffer_ms;
Ratio level_dbfs;
SaturationProtectorState saturation_protector;
};
static_assert(std::is_trivially_copyable<LevelEstimatorState>::value, "");
void ResetLevelEstimatorState(LevelEstimatorState& state) const;
void DumpDebugData() const;
ApmDataDumper* const apm_data_dumper_;  // Not owned.
const AudioProcessing::Config::GainController2::LevelEstimator
level_estimator_type_;
const int adjacent_speech_frames_threshold_;
const float initial_saturation_margin_db_;
const float extra_saturation_margin_db_;
// State updated on every speech frame.
LevelEstimatorState preliminary_state_;
// State committed after a long enough adjacent speech sequence.
LevelEstimatorState reliable_state_;
// Cached combined level estimate exposed via `level_dbfs()`.
float level_dbfs_;
int num_adjacent_speech_frames_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_

View File

@ -0,0 +1,65 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.h"
#include <cmath>
#include <vector>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
// Wraps `AdaptiveModeLevelEstimator` behind the legacy AGC interface and
// initializes the target level to the default AGC2 headroom.
AdaptiveModeLevelEstimatorAgc::AdaptiveModeLevelEstimatorAgc(
ApmDataDumper* apm_data_dumper)
: level_estimator_(apm_data_dumper) {
set_target_level_dbfs(kDefaultAgc2LevelHeadroomDbfs);
}
// |audio| must be mono; in a multi-channel stream, provide the first (usually
// left) channel. Runs the AGC2 VAD on the frame, tracks the time spent in
// confident speech and feeds the result to the level estimator.
void AdaptiveModeLevelEstimatorAgc::Process(const int16_t* audio,
                                            size_t length,
                                            int sample_rate_hz) {
  // Convert the int16 samples to float for the frame view.
  std::vector<float> float_audio_frame(audio, audio + length);
  // Use data() instead of &float_audio_frame[0]: operator[] on an empty
  // vector is undefined behavior, while data() is well-defined for
  // length == 0.
  const float* const first_channel = float_audio_frame.data();
  AudioFrameView<const float> frame_view(&first_channel, /*num_channels=*/1,
                                         length);
  const auto vad_prob = agc2_vad_.AnalyzeFrame(frame_view);
  latest_voice_probability_ = vad_prob.speech_probability;
  if (latest_voice_probability_ > kVadConfidenceThreshold) {
    time_in_ms_since_last_estimate_ += kFrameDurationMs;
  }
  level_estimator_.Update(vad_prob);
}
// Retrieves the difference between the target RMS level and the current
// signal RMS level in dB. Returns true if an update is available and false
// otherwise, in which case |error| should be ignored and no action taken.
// Writes the rounded difference between the target level and the estimated
// speech level (in dB) into |error|. Returns false until enough speech has
// been observed since the last reported estimate; in that case |error| is
// left untouched and must be ignored by the caller.
bool AdaptiveModeLevelEstimatorAgc::GetRmsErrorDb(int* error) {
  const bool confident =
      time_in_ms_since_last_estimate_ > kTimeUntilConfidentMs;
  if (!confident) {
    return false;
  }
  // Round to the nearest integer number of dB.
  const float error_db =
      target_level_dbfs() - level_estimator_.level_dbfs() + 0.5f;
  *error = std::floor(error_db);
  time_in_ms_since_last_estimate_ = 0;
  return true;
}
// Resets the internal speech level estimation state.
void AdaptiveModeLevelEstimatorAgc::Reset() {
  level_estimator_.Reset();
}
// Returns the speech probability computed for the most recent frame passed to
// Process(); 0 before the first call.
float AdaptiveModeLevelEstimatorAgc::voice_probability() const {
  return latest_voice_probability_;
}
} // namespace webrtc

View File

@ -0,0 +1,51 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_
#include <stddef.h>
#include <stdint.h>
#include "modules/audio_processing/agc/agc.h"
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
namespace webrtc {
// Adapts the AGC2 adaptive-mode components (VAD plus speech level estimator)
// to the legacy AGC |Agc| interface.
class AdaptiveModeLevelEstimatorAgc : public Agc {
 public:
  explicit AdaptiveModeLevelEstimatorAgc(ApmDataDumper* apm_data_dumper);
  // |audio| must be mono; in a multi-channel stream, provide the first (usually
  // left) channel.
  void Process(const int16_t* audio,
               size_t length,
               int sample_rate_hz) override;
  // Retrieves the difference between the target RMS level and the current
  // signal RMS level in dB. Returns true if an update is available and false
  // otherwise, in which case |error| should be ignored and no action taken.
  bool GetRmsErrorDb(int* error) override;
  void Reset() override;
  // Returns the speech probability of the last processed frame.
  float voice_probability() const override;
 private:
  // Milliseconds of confident speech required before reporting an RMS error.
  static constexpr int kTimeUntilConfidentMs = 700;
  // Target headroom (dBFS) passed to Agc::set_target_level_dbfs().
  static constexpr int kDefaultAgc2LevelHeadroomDbfs = -1;
  // Speech time accumulated since the last reported estimate.
  int32_t time_in_ms_since_last_estimate_ = 0;
  AdaptiveModeLevelEstimator level_estimator_;
  VadLevelAnalyzer agc2_vad_;
  float latest_voice_probability_ = 0.f;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_

View File

@ -0,0 +1,86 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
#include <stddef.h>
namespace webrtc {
// Minimum and maximum sample values of 16-bit audio represented as floats.
constexpr float kMinFloatS16Value = -32768.f;
constexpr float kMaxFloatS16Value = 32767.f;
constexpr float kMaxAbsFloatS16Value = 32768.0f;
// Frame layout: 10 ms frames, split into 20 sub-frames for level estimation.
constexpr size_t kFrameDurationMs = 10;
constexpr size_t kSubFramesInFrame = 20;
// 10 ms at 48 kHz.
constexpr size_t kMaximalNumberOfSamplesPerChannel = 480;
// Attack smoothing coefficient for the level estimator (0 means envelope
// increases are not smoothed at all).
constexpr float kAttackFilterConstant = 0.f;
// Adaptive digital gain applier settings below.
constexpr float kHeadroomDbfs = 1.f;
constexpr float kMaxGainDb = 30.f;
constexpr float kInitialAdaptiveDigitalGainDb = 8.f;
// At what limiter levels should we start decreasing the adaptive digital gain.
constexpr float kLimiterThresholdForAgcGainDbfs = -kHeadroomDbfs;
// This is the threshold for speech. Speech frames are used for updating the
// speech level, measuring the amount of speech, and decide when to allow target
// gain reduction.
constexpr float kVadConfidenceThreshold = 0.9f;
// The amount of 'memory' of the Level Estimator. Decides leak factors.
constexpr size_t kFullBufferSizeMs = 1200;
constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs;
constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
// Robust VAD probability and speech decisions.
constexpr float kDefaultSmoothedVadProbabilityAttack = 1.f;
constexpr int kDefaultLevelEstimatorAdjacentSpeechFramesThreshold = 1;
// Saturation Protector settings.
constexpr float kDefaultInitialSaturationMarginDb = 20.f;
constexpr float kDefaultExtraSaturationMarginDb = 2.f;
constexpr size_t kPeakEnveloperSuperFrameLengthMs = 400;
static_assert(kFullBufferSizeMs % kPeakEnveloperSuperFrameLengthMs == 0,
              "Full buffer size should be a multiple of super frame length for "
              "optimal Saturation Protector performance.");
constexpr size_t kPeakEnveloperBufferSize =
    kFullBufferSizeMs / kPeakEnveloperSuperFrameLengthMs + 1;
// This value is 10 ** (-1/20 * frame_size_ms / satproc_attack_ms),
// where satproc_attack_ms is 5000.
constexpr float kSaturationProtectorAttackConstant = 0.9988493699365052f;
// This value is 10 ** (-1/20 * frame_size_ms / satproc_decay_ms),
// where satproc_decay_ms is 1000.
constexpr float kSaturationProtectorDecayConstant = 0.9997697679981565f;
// This is computed from kDecayMs by
// 10 ** (-1/20 * subframe_duration / kDecayMs).
// |subframe_duration| is |kFrameDurationMs / kSubFramesInFrame|.
// kDecayMs is defined in agc2_testing_common.h
constexpr float kDecayFilterConstant = 0.9998848773724686f;
// Number of interpolation points for each region of the limiter.
// These values have been tuned to limit the interpolated gain curve error given
// the limiter parameters and allowing a maximum error of +/- 32768^-1.
constexpr size_t kInterpolatedGainCurveKneePoints = 22;
constexpr size_t kInterpolatedGainCurveBeyondKneePoints = 10;
constexpr size_t kInterpolatedGainCurveTotalPoints =
    kInterpolatedGainCurveKneePoints + kInterpolatedGainCurveBeyondKneePoints;
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_

View File

@ -0,0 +1,33 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/agc2_testing_common.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace test {
// Returns |num_points| values evenly spaced over [l, r], with both endpoints
// included exactly. Requires num_points >= 2.
std::vector<double> LinSpace(const double l,
                             const double r,
                             size_t num_points) {
  RTC_CHECK(num_points >= 2);
  const double step = (r - l) / (num_points - 1.0);
  std::vector<double> points;
  points.reserve(num_points);
  points.push_back(l);
  // Interior points; first and last are set exactly to avoid rounding drift.
  for (size_t i = 1; i + 1 < num_points; ++i) {
    points.push_back(l + i * step);
  }
  points.push_back(r);
  return points;
}
} // namespace test
} // namespace webrtc

View File

@ -0,0 +1,78 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
#include <math.h>
#include <limits>
#include <vector>
#include "rtc_base/checks.h"
namespace webrtc {
namespace test {
// Level Estimator test parameters.
constexpr float kDecayMs = 500.f;
// Limiter parameters.
constexpr float kLimiterMaxInputLevelDbFs = 1.f;
constexpr float kLimiterKneeSmoothnessDb = 1.f;
constexpr float kLimiterCompressionRatio = 5.f;
constexpr float kPi = 3.1415926536f;
// Returns |num_points| linearly spaced values in [l, r], endpoints included.
std::vector<double> LinSpace(const double l, const double r, size_t num_points);
// Generates samples of a sine wave at the given frequency and sample rate,
// with amplitude 1000 (FloatS16 scale).
class SineGenerator {
 public:
  SineGenerator(float frequency, int rate)
      : frequency_(frequency), rate_(rate) {}
  // Returns the next sample of the sine wave.
  float operator()() {
    const float phase_increment = frequency_ / rate_ * 2 * kPi;
    x_radians_ += phase_increment;
    // Wrap the phase to keep it bounded over long runs.
    if (x_radians_ > 2 * kPi) {
      x_radians_ -= 2 * kPi;
    }
    return 1000.f * sinf(x_radians_);
  }

 private:
  float frequency_;
  int rate_;
  float x_radians_ = 0.f;
};
// Generates a periodic pulse train: one full-scale int16 sample per period,
// every other sample equal to 10.
class PulseGenerator {
 public:
  PulseGenerator(float frequency, int rate)
      : samples_period_(
            static_cast<int>(static_cast<float>(rate) / frequency)) {
    RTC_DCHECK_GT(rate, frequency);
  }
  // Returns the next sample of the pulse train.
  float operator()() {
    ++sample_counter_;
    if (sample_counter_ >= samples_period_) {
      sample_counter_ -= samples_period_;
    }
    const bool is_pulse = sample_counter_ == 0;
    return is_pulse ? static_cast<float>(std::numeric_limits<int16_t>::max())
                    : 10.f;
  }

 private:
  int samples_period_;
  int sample_counter_ = 0;
};
} // namespace test
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_

View File

@ -0,0 +1,36 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/biquad_filter.h"
#include <stddef.h>
namespace webrtc {
// Transposed direct form I implementation of a bi-quad filter applied to an
// input signal |x| to produce an output signal |y|.
// Filters the input |x| into |y| using the configured bi-quad coefficients,
// updating the stored past-input/past-output state. |x| and |y| must have the
// same length and may refer to the same buffer (in-place use).
void BiQuadFilter::Process(rtc::ArrayView<const float> x,
                           rtc::ArrayView<float> y) {
  const auto& c = coefficients_;
  auto& s = biquad_state_;
  for (size_t n = 0; n < x.size(); ++n) {
    // Copy the input sample before writing y[n] so that x and y may alias.
    const float x_n = x[n];
    y[n] = c.b[0] * x_n + c.b[1] * s.b[0] + c.b[2] * s.b[1] -
           c.a[0] * s.a[0] - c.a[1] * s.a[1];
    // Shift the delay lines.
    s.b[1] = s.b[0];
    s.b[0] = x_n;
    s.a[1] = s.a[0];
    s.a[0] = y[n];
  }
}
} // namespace webrtc

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
#include <algorithm>
#include "api/array_view.h"
#include "rtc_base/arraysize.h"
#include "rtc_base/constructor_magic.h"
namespace webrtc {
// Bi-quad filter with externally settable, normalized coefficients.
class BiQuadFilter {
 public:
  // Normalized filter coefficients.
  //        b_0 + b_1 • z^(-1) + b_2 • z^(-2)
  // H(z) = ---------------------------------
  //          1 + a_1 • z^(-1) + a_2 • z^(-2)
  struct BiQuadCoefficients {
    float b[3];
    float a[2];
  };
  BiQuadFilter() = default;
  // Sets the coefficients; does not clear the internal state.
  void Initialize(const BiQuadCoefficients& coefficients) {
    coefficients_ = coefficients;
  }
  // Zeroes the delay lines.
  void Reset() { biquad_state_.Reset(); }
  // Produces a filtered output y of the input x. Both x and y need to
  // have the same length. In-place modification is allowed.
  void Process(rtc::ArrayView<const float> x, rtc::ArrayView<float> y);

 private:
  // Past input (b) and past output (a) samples.
  struct BiQuadState {
    BiQuadState() { Reset(); }
    void Reset() {
      for (float& v : b) {
        v = 0.f;
      }
      for (float& v : a) {
        v = 0.f;
      }
    }
    float b[2];
    float a[2];
  };
  BiQuadState biquad_state_;
  BiQuadCoefficients coefficients_;
  RTC_DISALLOW_COPY_AND_ASSIGN(BiQuadFilter);
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_

View File

@ -0,0 +1,229 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/compute_interpolated_gain_curve.h"
#include <algorithm>
#include <cmath>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/agc2_testing_common.h"
#include "modules/audio_processing/agc2/limiter_db_gain_curve.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
// Returns the slope and y-intercept (m, q) of the line tangent to the limiter
// gain curve at |x|, so that y = m * x + q approximates the curve locally.
std::pair<double, double> ComputeLinearApproximationParams(
    const LimiterDbGainCurve* limiter,
    const double x) {
  const double slope = limiter->GetGainFirstDerivativeLinear(x);
  const double intercept = limiter->GetGainLinear(x) - slope * x;
  return std::make_pair(slope, intercept);
}
// Approximates the limiter gain curve on [x0, x1] with the two tangent lines
// computed at x0 and x1 and returns the area under that piecewise linear
// approximation.
double ComputeAreaUnderPiecewiseLinearApproximation(
    const LimiterDbGainCurve* limiter,
    const double x0,
    const double x1) {
  RTC_CHECK_LT(x0, x1);
  // Linear approximation in x0 and x1.
  double m0, q0, m1, q1;
  std::tie(m0, q0) = ComputeLinearApproximationParams(limiter, x0);
  std::tie(m1, q1) = ComputeLinearApproximationParams(limiter, x1);
  // Intersection point between two adjacent linear pieces.
  RTC_CHECK_NE(m1, m0);
  const double x_split = (q0 - q1) / (m1 - m0);
  RTC_CHECK_LT(x0, x_split);
  RTC_CHECK_LT(x_split, x1);
  // Analytic integral of y = m * x + q over [x_l, x_r].
  auto area_under_linear_piece = [](double x_l, double x_r, double m,
                                    double q) {
    return x_r * (m * x_r / 2.0 + q) - x_l * (m * x_l / 2.0 + q);
  };
  return area_under_linear_piece(x0, x_split, m0, q0) +
         area_under_linear_piece(x_split, x1, m1, q1);
}
// Computes the approximation error in the limiter region for a given interval.
// The error is computed as the difference between the areas beneath the limiter
// curve to approximate and its linear under-approximation.
// Computes the approximation error in the limiter region for a given interval.
// The error is the difference between the area beneath the limiter curve and
// the area beneath its linear under-approximation (always non-negative).
double LimiterUnderApproximationNegativeError(const LimiterDbGainCurve* limiter,
                                              const double x0,
                                              const double x1) {
  // Exact area under the limiter curve on [x0, x1].
  const double exact_area = limiter->GetGainIntegralLinear(x0, x1);
  // Area under the piecewise linear under-approximation on the same interval.
  const double approximated_area =
      ComputeAreaUnderPiecewiseLinearApproximation(limiter, x0, x1);
  RTC_CHECK_GE(exact_area, approximated_area);
  return exact_area - approximated_area;
}
// Automatically finds where to sample the beyond-knee region of a limiter using
// a greedy optimization algorithm that iteratively decreases the approximation
// error.
// The solution is sub-optimal because the algorithm is greedy and the points
// are assigned by halving intervals (starting with the whole beyond-knee region
// as a single interval). However, even if sub-optimal, this algorithm works
// well in practice and it is efficiently implemented using priority queues.
std::vector<double> SampleLimiterRegion(const LimiterDbGainCurve* limiter) {
  static_assert(kInterpolatedGainCurveBeyondKneePoints > 2, "");
  // Sub-interval of the beyond-knee region together with the approximation
  // error committed on it. Ordered by error so that the priority queue
  // surfaces the interval with the largest error first.
  struct Interval {
    Interval() = default;  // Ctor required by std::priority_queue.
    Interval(double l, double r, double e) : x0(l), x1(r), error(e) {
      RTC_CHECK(x0 < x1);
    }
    bool operator<(const Interval& other) const { return error < other.error; }
    double x0;
    double x1;
    double error;
  };
  // Start with the whole beyond-knee region as a single interval.
  std::priority_queue<Interval, std::vector<Interval>> q;
  q.emplace(limiter->limiter_start_linear(), limiter->max_input_level_linear(),
            LimiterUnderApproximationNegativeError(
                limiter, limiter->limiter_start_linear(),
                limiter->max_input_level_linear()));
  // Iteratively find points by halving the interval with greatest error.
  while (q.size() < kInterpolatedGainCurveBeyondKneePoints) {
    // Get the interval with highest error.
    const auto interval = q.top();
    q.pop();
    // Split |interval| and enqueue.
    double x_split = (interval.x0 + interval.x1) / 2.0;
    q.emplace(interval.x0, x_split,
              LimiterUnderApproximationNegativeError(limiter, interval.x0,
                                                     x_split));  // Left.
    q.emplace(x_split, interval.x1,
              LimiterUnderApproximationNegativeError(limiter, x_split,
                                                     interval.x1));  // Right.
  }
  // Copy x1 values and sort them.
  RTC_CHECK_EQ(q.size(), kInterpolatedGainCurveBeyondKneePoints);
  std::vector<double> samples(kInterpolatedGainCurveBeyondKneePoints);
  for (size_t i = 0; i < kInterpolatedGainCurveBeyondKneePoints; ++i) {
    const auto interval = q.top();
    q.pop();
    samples[i] = interval.x1;
  }
  RTC_CHECK(q.empty());
  std::sort(samples.begin(), samples.end());
  return samples;
}
// Compute the parameters to over-approximate the knee region via linear
// interpolation. Over-approximating is saturation-safe since the knee region is
// convex.
void PrecomputeKneeApproxParams(const LimiterDbGainCurve* limiter,
                                test::InterpolatedParameters* parameters) {
  static_assert(kInterpolatedGainCurveKneePoints > 2, "");
  // Get |kInterpolatedGainCurveKneePoints| - 1 equally spaced points.
  const std::vector<double> points = test::LinSpace(
      limiter->knee_start_linear(), limiter->limiter_start_linear(),
      kInterpolatedGainCurveKneePoints - 1);
  // Set the first two points. The second is computed to help with the beginning
  // of the knee region, which has high curvature.
  parameters->computed_approximation_params_x[0] = points[0];
  parameters->computed_approximation_params_x[1] =
      (points[0] + points[1]) / 2.0;
  // Copy the remaining points.
  std::copy(std::begin(points) + 1, std::end(points),
            std::begin(parameters->computed_approximation_params_x) + 2);
  // Compute (m, q) pairs for each linear piece y = mx + q.
  for (size_t i = 0; i < kInterpolatedGainCurveKneePoints - 1; ++i) {
    const double x0 = parameters->computed_approximation_params_x[i];
    const double x1 = parameters->computed_approximation_params_x[i + 1];
    const double y0 = limiter->GetGainLinear(x0);
    const double y1 = limiter->GetGainLinear(x1);
    RTC_CHECK_NE(x1, x0);
    // Slope and y-intercept of the secant through (x0, y0) and (x1, y1).
    parameters->computed_approximation_params_m[i] = (y1 - y0) / (x1 - x0);
    parameters->computed_approximation_params_q[i] =
        y0 - parameters->computed_approximation_params_m[i] * x0;
  }
}
// Compute the parameters to under-approximate the beyond-knee region via linear
// interpolation and greedy sampling. Under-approximating is saturation-safe
// since the beyond-knee region is concave.
void PrecomputeBeyondKneeApproxParams(
    const LimiterDbGainCurve* limiter,
    test::InterpolatedParameters* parameters) {
  // Find points on which the linear pieces are tangent to the gain curve.
  const auto samples = SampleLimiterRegion(limiter);
  // Parametrize each linear piece.
  double m, q;
  // Last knee piece: tangent at the final knee boundary point (index
  // kInterpolatedGainCurveKneePoints - 1).
  std::tie(m, q) = ComputeLinearApproximationParams(
      limiter,
      parameters
          ->computed_approximation_params_x[kInterpolatedGainCurveKneePoints -
                                            1]);
  parameters
      ->computed_approximation_params_m[kInterpolatedGainCurveKneePoints - 1] =
      m;
  parameters
      ->computed_approximation_params_q[kInterpolatedGainCurveKneePoints - 1] =
      q;
  // Beyond-knee pieces: tangents at the greedily sampled points.
  for (size_t i = 0; i < samples.size(); ++i) {
    std::tie(m, q) = ComputeLinearApproximationParams(limiter, samples[i]);
    parameters
        ->computed_approximation_params_m[i +
                                          kInterpolatedGainCurveKneePoints] = m;
    parameters
        ->computed_approximation_params_q[i +
                                          kInterpolatedGainCurveKneePoints] = q;
  }
  // Find the point of intersection between adjacent linear pieces. They will be
  // used as boundaries between adjacent linear pieces.
  for (size_t i = kInterpolatedGainCurveKneePoints;
       i < kInterpolatedGainCurveKneePoints +
               kInterpolatedGainCurveBeyondKneePoints;
       ++i) {
    RTC_CHECK_NE(parameters->computed_approximation_params_m[i],
                 parameters->computed_approximation_params_m[i - 1]);
    parameters->computed_approximation_params_x[i] =
        (  // Formula: (q0 - q1) / (m1 - m0).
            parameters->computed_approximation_params_q[i - 1] -
            parameters->computed_approximation_params_q[i]) /
        (parameters->computed_approximation_params_m[i] -
         parameters->computed_approximation_params_m[i - 1]);
  }
}
} // namespace
namespace test {
// Computes the piecewise linear interpolation parameters (piece boundaries,
// slopes and y-intercepts) that approximate the limiter gain curve over both
// the knee and the beyond-knee regions.
InterpolatedParameters ComputeInterpolatedGainCurveApproximationParams() {
  LimiterDbGainCurve limiter;
  InterpolatedParameters params;
  // Start from zeroed parameter arrays.
  params.computed_approximation_params_x.fill(0.0f);
  params.computed_approximation_params_m.fill(0.0f);
  params.computed_approximation_params_q.fill(0.0f);
  // Knee region first (over-approximation), then the beyond-knee region
  // (under-approximation).
  PrecomputeKneeApproxParams(&limiter, &params);
  PrecomputeBeyondKneeApproxParams(&limiter, &params);
  return params;
}
} // namespace test
} // namespace webrtc

View File

@ -0,0 +1,48 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_
#define MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_
#include <array>
#include "modules/audio_processing/agc2/agc2_common.h"
namespace webrtc {
namespace test {
// Parameters for interpolated gain curve using under-approximation to
// avoid saturation.
//
// The saturation gain is defined in order to let hard-clipping occur for
// those samples having a level that falls in the saturation region. It is an
// upper bound of the actual gain to apply - i.e., that returned by the
// limiter.
// Knee and beyond-knee regions approximation parameters.
// The gain curve is approximated as a piece-wise linear function.
// |approx_params_x_| are the boundaries between adjacent linear pieces,
// |approx_params_m_| and |approx_params_q_| are the slope and the y-intercept
// values of each piece.
struct InterpolatedParameters {
  // Boundaries between adjacent linear pieces.
  std::array<float, kInterpolatedGainCurveTotalPoints>
      computed_approximation_params_x;
  // Slopes of the linear pieces.
  std::array<float, kInterpolatedGainCurveTotalPoints>
      computed_approximation_params_m;
  // y-intercepts of the linear pieces.
  std::array<float, kInterpolatedGainCurveTotalPoints>
      computed_approximation_params_q;
};
InterpolatedParameters ComputeInterpolatedGainCurveApproximationParams();
} // namespace test
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_

View File

@ -0,0 +1,99 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/down_sampler.h"
#include <string.h>
#include <algorithm>
#include "modules/audio_processing/agc2/biquad_filter.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
// Processing chunk duration; all buffers hold 10 ms of audio.
constexpr int kChunkSizeMs = 10;
constexpr int kSampleRate8kHz = 8000;
constexpr int kSampleRate16kHz = 16000;
constexpr int kSampleRate32kHz = 32000;
constexpr int kSampleRate48kHz = 48000;
// Bandlimiter coefficients computed based on that only
// the first 40 bins of the spectrum for the downsampled
// signal are used.
// [B,A] = butter(2,(41/64*4000)/8000)
const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_16kHz = {
    {0.1455f, 0.2911f, 0.1455f},
    {-0.6698f, 0.2520f}};
// [B,A] = butter(2,(41/64*4000)/16000)
const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_32kHz = {
    {0.0462f, 0.0924f, 0.0462f},
    {-1.3066f, 0.4915f}};
// [B,A] = butter(2,(41/64*4000)/24000)
const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_48kHz = {
    {0.0226f, 0.0452f, 0.0226f},
    {-1.5320f, 0.6224f}};
}  // namespace
// Constructs the down sampler with a default input rate of 48 kHz
// (48000 == kSampleRate48kHz); the rate can be changed via Initialize().
DownSampler::DownSampler(ApmDataDumper* data_dumper)
    : data_dumper_(data_dumper) {
  Initialize(48000);
}
// Configures the down sampler for |sample_rate_hz| (must be 8, 16, 32 or
// 48 kHz) and selects the matching anti-aliasing low-pass filter.
void DownSampler::Initialize(int sample_rate_hz) {
  RTC_DCHECK(
      sample_rate_hz == kSampleRate8kHz || sample_rate_hz == kSampleRate16kHz ||
      sample_rate_hz == kSampleRate32kHz || sample_rate_hz == kSampleRate48kHz);
  sample_rate_hz_ = sample_rate_hz;
  down_sampling_factor_ = rtc::CheckedDivExact(sample_rate_hz_, 8000);
  // Note that the down sampling filter is not used if the sample rate is 8
  // kHz.
  if (sample_rate_hz_ == kSampleRate16kHz) {
    low_pass_filter_.Initialize(kLowPassFilterCoefficients_16kHz);
  } else if (sample_rate_hz_ == kSampleRate32kHz) {
    low_pass_filter_.Initialize(kLowPassFilterCoefficients_32kHz);
  } else if (sample_rate_hz_ == kSampleRate48kHz) {
    low_pass_filter_.Initialize(kLowPassFilterCoefficients_48kHz);
  }
}
// Band-limits |in| to 4 kHz (unless the input is already at 8 kHz) and
// decimates it into |out|. |in| must hold 10 ms at the configured rate and
// |out| 10 ms at 8 kHz. Both buffers are dumped for debugging.
void DownSampler::DownSample(rtc::ArrayView<const float> in,
                             rtc::ArrayView<float> out) {
  data_dumper_->DumpWav("lc_down_sampler_input", in, sample_rate_hz_, 1);
  RTC_DCHECK_EQ(sample_rate_hz_ * kChunkSizeMs / 1000, in.size());
  RTC_DCHECK_EQ(kSampleRate8kHz * kChunkSizeMs / 1000, out.size());
  // Scratch buffer sized for the largest supported rate (48 kHz).
  const size_t kMaxNumFrames = kSampleRate48kHz * kChunkSizeMs / 1000;
  float x[kMaxNumFrames];
  // Band-limit the signal to 4 kHz.
  if (sample_rate_hz_ != kSampleRate8kHz) {
    low_pass_filter_.Process(in, rtc::ArrayView<float>(x, in.size()));
    // Downsample the signal by keeping every |down_sampling_factor_|-th
    // filtered sample.
    size_t k = 0;
    for (size_t j = 0; j < out.size(); ++j) {
      RTC_DCHECK_GT(kMaxNumFrames, k);
      out[j] = x[k];
      k += down_sampling_factor_;
    }
  } else {
    // Already at 8 kHz: pass the samples through unchanged.
    std::copy(in.data(), in.data() + in.size(), out.data());
  }
  data_dumper_->DumpWav("lc_down_sampler_output", out, kSampleRate8kHz, 1);
}
} // namespace webrtc

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_
#include "api/array_view.h"
#include "modules/audio_processing/agc2/biquad_filter.h"
namespace webrtc {
class ApmDataDumper;
// Decimates an audio stream to 8 kHz, applying an anti-aliasing low-pass
// filter first when the input rate is above 8 kHz.
class DownSampler {
 public:
  explicit DownSampler(ApmDataDumper* data_dumper);
  DownSampler() = delete;
  DownSampler(const DownSampler&) = delete;
  DownSampler& operator=(const DownSampler&) = delete;
  // Sets up the down sampler for |sample_rate_hz| (8, 16, 32 or 48 kHz).
  void Initialize(int sample_rate_hz);
  // Down samples 10 ms of |in| at the configured rate into 10 ms of |out|
  // at 8 kHz.
  void DownSample(rtc::ArrayView<const float> in, rtc::ArrayView<float> out);
 private:
  ApmDataDumper* data_dumper_;
  int sample_rate_hz_;
  int down_sampling_factor_;
  BiQuadFilter low_pass_filter_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_

View File

@ -0,0 +1,112 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
#include <algorithm>
#include <cmath>
#include "api/array_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
// Envelope filter state at start-up and after Reset().
constexpr float kInitialFilterStateLevel = 0.f;
}  // namespace
// Creates the estimator for |sample_rate_hz| and logs the configured rate
// through |apm_data_dumper| (must be non-null) for debugging.
FixedDigitalLevelEstimator::FixedDigitalLevelEstimator(
    size_t sample_rate_hz,
    ApmDataDumper* apm_data_dumper)
    : apm_data_dumper_(apm_data_dumper),
      filter_state_level_(kInitialFilterStateLevel) {
  SetSampleRate(sample_rate_hz);
  CheckParameterCombination();
  RTC_DCHECK(apm_data_dumper_);
  apm_data_dumper_->DumpRaw("agc2_level_estimator_samplerate", sample_rate_hz);
}
// Debug-checks that the frame/sub-frame split is consistent: a frame must
// contain an integer number of sub-frames, each longer than one sample.
void FixedDigitalLevelEstimator::CheckParameterCombination() {
  RTC_DCHECK_GT(samples_in_frame_, 0);
  RTC_DCHECK_LE(kSubFramesInFrame, samples_in_frame_);
  RTC_DCHECK_EQ(samples_in_frame_ % kSubFramesInFrame, 0);
  RTC_DCHECK_GT(samples_in_sub_frame_, 1);
}
// Computes the per-sub-frame peak envelope of |float_frame| (max absolute
// sample across all channels), shifts increases one sub-frame earlier, then
// applies attack/decay exponential smoothing. Updates the stored filter state.
std::array<float, kSubFramesInFrame> FixedDigitalLevelEstimator::ComputeLevel(
    const AudioFrameView<const float>& float_frame) {
  RTC_DCHECK_GT(float_frame.num_channels(), 0);
  RTC_DCHECK_EQ(float_frame.samples_per_channel(), samples_in_frame_);
  // Compute max envelope without smoothing.
  std::array<float, kSubFramesInFrame> envelope{};
  for (size_t channel_idx = 0; channel_idx < float_frame.num_channels();
       ++channel_idx) {
    const auto channel = float_frame.channel(channel_idx);
    for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
      for (size_t sample_in_sub_frame = 0;
           sample_in_sub_frame < samples_in_sub_frame_; ++sample_in_sub_frame) {
        envelope[sub_frame] =
            std::max(envelope[sub_frame],
                     std::abs(channel[sub_frame * samples_in_sub_frame_ +
                                      sample_in_sub_frame]));
      }
    }
  }
  // Make sure envelope increases happen one step earlier so that the
  // corresponding *gain decrease* doesn't miss a sudden signal
  // increase due to interpolation.
  for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame - 1; ++sub_frame) {
    if (envelope[sub_frame] < envelope[sub_frame + 1]) {
      envelope[sub_frame] = envelope[sub_frame + 1];
    }
  }
  // Add attack / decay smoothing.
  for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
    const float envelope_value = envelope[sub_frame];
    // Rising envelope uses the attack constant, falling envelope the decay
    // constant; the result becomes the next filter state.
    if (envelope_value > filter_state_level_) {
      envelope[sub_frame] = envelope_value * (1 - kAttackFilterConstant) +
                            filter_state_level_ * kAttackFilterConstant;
    } else {
      envelope[sub_frame] = envelope_value * (1 - kDecayFilterConstant) +
                            filter_state_level_ * kDecayFilterConstant;
    }
    filter_state_level_ = envelope[sub_frame];
    // Dump data for debug.
    RTC_DCHECK(apm_data_dumper_);
    const auto channel = float_frame.channel(0);
    apm_data_dumper_->DumpRaw("agc2_level_estimator_samples",
                              samples_in_sub_frame_,
                              &channel[sub_frame * samples_in_sub_frame_]);
    apm_data_dumper_->DumpRaw("agc2_level_estimator_level",
                              envelope[sub_frame]);
  }
  return envelope;
}
// Updates the frame and sub-frame sizes for |sample_rate_hz|. The rate must
// yield integer frame and sub-frame sizes (CheckedDivExact enforces this).
void FixedDigitalLevelEstimator::SetSampleRate(size_t sample_rate_hz) {
  samples_in_frame_ = rtc::CheckedDivExact(sample_rate_hz * kFrameDurationMs,
                                           static_cast<size_t>(1000));
  samples_in_sub_frame_ =
      rtc::CheckedDivExact(samples_in_frame_, kSubFramesInFrame);
  CheckParameterCombination();
}
// Resets the envelope filter state to its initial level.
void FixedDigitalLevelEstimator::Reset() {
  filter_state_level_ = kInitialFilterStateLevel;
}
} // namespace webrtc

View File

@ -0,0 +1,65 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_
#include <array>
#include <vector>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/constructor_magic.h"
namespace webrtc {
class ApmDataDumper;
// Produces a smooth signal level estimate from an input audio
// stream. The estimate smoothing is done through exponential
// filtering.
class FixedDigitalLevelEstimator {
 public:
  // Sample rates are allowed if the number of samples in a frame
  // (sample_rate_hz * kFrameDurationMs / 1000) is divisible by
  // kSubFramesInSample. For kFrameDurationMs=10 and
  // kSubFramesInSample=20, this means that sample_rate_hz has to be
  // divisible by 2000.
  FixedDigitalLevelEstimator(size_t sample_rate_hz,
                             ApmDataDumper* apm_data_dumper);
  // The input is assumed to be in FloatS16 format. Scaled input will
  // produce similarly scaled output. A frame of with kFrameDurationMs
  // ms of audio produces a level estimates in the same scale. The
  // level estimate contains kSubFramesInFrame values.
  std::array<float, kSubFramesInFrame> ComputeLevel(
      const AudioFrameView<const float>& float_frame);
  // Rate may be changed at any time (but not concurrently) from the
  // value passed to the constructor. The class is not thread safe.
  void SetSampleRate(size_t sample_rate_hz);
  // Resets the level estimator internal state.
  void Reset();
  // Returns the last smoothed envelope value computed by ComputeLevel().
  float LastAudioLevel() const { return filter_state_level_; }
 private:
  void CheckParameterCombination();
  ApmDataDumper* const apm_data_dumper_ = nullptr;
  // Smoothed envelope level, carried across sub-frames and frames.
  float filter_state_level_;
  size_t samples_in_frame_;
  size_t samples_in_sub_frame_;
  RTC_DISALLOW_COPY_AND_ASSIGN(FixedDigitalLevelEstimator);
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_

View File

@ -0,0 +1,101 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/fixed_gain_controller.h"
#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
// Returns true when the gain factor is so close to 1 that applying it
// would not change int16 samples.
bool CloseToOne(float gain_factor) {
  const float tolerance = 1.f / kMaxFloatS16Value;
  return gain_factor >= 1.f - tolerance && gain_factor <= 1.f + tolerance;
}
} // namespace
// Delegates to the two-argument constructor with the default prefix.
FixedGainController::FixedGainController(ApmDataDumper* apm_data_dumper)
    : FixedGainController(apm_data_dumper, "Agc2") {}
// Constructs the controller; |histogram_name_prefix| selects the UMA
// histogram family used by the limiter. The limiter is created at 48 kHz;
// call SetSampleRate() to change the rate.
FixedGainController::FixedGainController(ApmDataDumper* apm_data_dumper,
                                         std::string histogram_name_prefix)
    : apm_data_dumper_(apm_data_dumper),
      limiter_(48000, apm_data_dumper_, histogram_name_prefix) {
  // Do update histograms.xml when adding name prefixes.
  RTC_DCHECK(histogram_name_prefix == "" || histogram_name_prefix == "Test" ||
             histogram_name_prefix == "AudioMixer" ||
             histogram_name_prefix == "Agc2");
}
// Sets the fixed digital gain in dB. Changes in gain_to_apply_ cause
// discontinuities, so the gain is assumed to be set at the beginning of the
// call; if it is frequently changed, interpolation between values should be
// added.
void FixedGainController::SetGain(float gain_to_apply_db) {
  RTC_DCHECK_LE(-50.f, gain_to_apply_db);
  RTC_DCHECK_LE(gain_to_apply_db, 50.f);
  const float previous_gain = gain_to_apply_;
  gain_to_apply_ = DbToRatio(gain_to_apply_db);
  RTC_DCHECK_LT(0.f, gain_to_apply_);
  RTC_DLOG(LS_INFO) << "Gain to apply: " << gain_to_apply_db << " db.";
  if (gain_to_apply_ != previous_gain) {
    // Reset the limiter so it quickly reacts to the abrupt level change
    // caused by a large change of the applied gain.
    limiter_.Reset();
  }
}
// Forwards the sample rate change to the limiter.
void FixedGainController::SetSampleRate(size_t sample_rate_hz) {
  limiter_.SetSampleRate(sample_rate_hz);
}
// Applies the fixed gain, runs the limiter and hard-clips |signal| in place.
void FixedGainController::Process(AudioFrameView<float> signal) {
  // Skip the per-sample multiplications when the gain is close enough to 1
  // that int16 samples would be unaffected; one of the planned usages of the
  // FGC is to only use the limiter, in which case the gain is 1.0 and the
  // skip speeds processing up considerably.
  if (!CloseToOne(gain_to_apply_)) {
    for (size_t channel_idx = 0; channel_idx < signal.num_channels();
         ++channel_idx) {
      for (float& sample : signal.channel(channel_idx)) {
        sample *= gain_to_apply_;
      }
    }
  }
  // Run the limiter on the (possibly amplified) signal.
  limiter_.Process(signal);
  // Dump the first channel for debugging.
  const auto first_channel = signal.channel(0);
  apm_data_dumper_->DumpRaw("agc2_fixed_digital_gain_curve_applier",
                            first_channel.size(), first_channel.data());
  // Hard-clip every sample to the FloatS16 range.
  for (size_t channel_idx = 0; channel_idx < signal.num_channels();
       ++channel_idx) {
    for (float& sample : signal.channel(channel_idx)) {
      sample = rtc::SafeClamp(sample, kMinFloatS16Value, kMaxFloatS16Value);
    }
  }
}
// Returns the last audio level estimated by the limiter.
float FixedGainController::LastAudioLevel() const {
  return limiter_.LastAudioLevel();
}
} // namespace webrtc

View File

@ -0,0 +1,102 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/gain_applier.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
// Returns true when the gain factor is so close to 1 that applying it
// would not change int16 samples.
bool GainCloseToOne(float gain_factor) {
  const float tolerance = 1.f / kMaxFloatS16Value;
  return gain_factor >= 1.f - tolerance && gain_factor <= 1.f + tolerance;
}
// Clamps every sample of every channel to the FloatS16 range.
void ClipSignal(AudioFrameView<float> signal) {
  for (size_t channel_idx = 0; channel_idx < signal.num_channels();
       ++channel_idx) {
    for (float& sample : signal.channel(channel_idx)) {
      sample = rtc::SafeClamp(sample, kMinFloatS16Value, kMaxFloatS16Value);
    }
  }
}
// Applies a gain to |float_frame|, ramping linearly from |last_gain_linear|
// at the frame start towards |gain_at_end_of_frame_linear| at its end.
// |inverse_samples_per_channel| must equal 1 / samples_per_channel; passing
// it in avoids a division per frame.
void ApplyGainWithRamping(float last_gain_linear,
                          float gain_at_end_of_frame_linear,
                          float inverse_samples_per_channel,
                          AudioFrameView<float> float_frame) {
  // Do not modify the signal.
  if (last_gain_linear == gain_at_end_of_frame_linear &&
      GainCloseToOne(gain_at_end_of_frame_linear)) {
    return;
  }
  // Gain is constant and different from 1.
  if (last_gain_linear == gain_at_end_of_frame_linear) {
    for (size_t k = 0; k < float_frame.num_channels(); ++k) {
      rtc::ArrayView<float> channel_view = float_frame.channel(k);
      for (auto& sample : channel_view) {
        sample *= gain_at_end_of_frame_linear;
      }
    }
    return;
  }
  // The gain changes. We have to change slowly to avoid discontinuities.
  const float increment = (gain_at_end_of_frame_linear - last_gain_linear) *
                          inverse_samples_per_channel;
  float gain = last_gain_linear;
  // Sample-major iteration so that all channels at index i get the same gain.
  for (size_t i = 0; i < float_frame.samples_per_channel(); ++i) {
    for (size_t ch = 0; ch < float_frame.num_channels(); ++ch) {
      float_frame.channel(ch)[i] *= gain;
    }
    gain += increment;
  }
}
} // namespace
// Constructs the gain applier; |initial_gain_factor| is used as both the
// last applied and the target gain until SetGainFactor() is called.
GainApplier::GainApplier(bool hard_clip_samples, float initial_gain_factor)
    : hard_clip_samples_(hard_clip_samples),
      last_gain_factor_(initial_gain_factor),
      current_gain_factor_(initial_gain_factor) {}
// Applies the current gain to |signal|, ramping from the previously applied
// gain when the two differ, then optionally hard-clips the result.
void GainApplier::ApplyGain(AudioFrameView<float> signal) {
  // Re-derive cached per-frame constants whenever the frame size changes.
  const int num_samples = static_cast<int>(signal.samples_per_channel());
  if (num_samples != samples_per_channel_) {
    Initialize(signal.samples_per_channel());
  }
  ApplyGainWithRamping(last_gain_factor_, current_gain_factor_,
                       inverse_samples_per_channel_, signal);
  last_gain_factor_ = current_gain_factor_;
  if (hard_clip_samples_) {
    ClipSignal(signal);
  }
}
// Sets the target gain; ramping towards it happens in the next ApplyGain().
void GainApplier::SetGainFactor(float gain_factor) {
  RTC_DCHECK_GT(gain_factor, 0.f);
  current_gain_factor_ = gain_factor;
}
// Caches the frame size and its reciprocal (used for gain ramping).
void GainApplier::Initialize(size_t samples_per_channel) {
  RTC_DCHECK_GT(samples_per_channel, 0);
  const int num_samples = static_cast<int>(samples_per_channel);
  samples_per_channel_ = num_samples;
  inverse_samples_per_channel_ = 1.f / num_samples;
}
} // namespace webrtc

View File

@ -0,0 +1,44 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
#include <stddef.h>
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
// Applies a (possibly ramped) gain to an audio frame, optionally
// hard-clipping the result to the FloatS16 range.
class GainApplier {
 public:
  GainApplier(bool hard_clip_samples, float initial_gain_factor);
  // Applies the current gain to |signal| in place, ramping from the
  // previously applied gain when it differs.
  void ApplyGain(AudioFrameView<float> signal);
  // Sets the gain to ramp towards during the next ApplyGain() call.
  void SetGainFactor(float gain_factor);
  float GetGainFactor() const { return current_gain_factor_; }
 private:
  void Initialize(size_t samples_per_channel);
  // Whether to clip samples after gain is applied. If 'true', result
  // will fit in FloatS16 range.
  const bool hard_clip_samples_;
  float last_gain_factor_;
  // If this value is not equal to 'last_gain_factor_', gain will be
  // ramped from 'last_gain_factor_' to this value during the next
  // 'ApplyGain'.
  float current_gain_factor_;
  // -1 until the first ApplyGain() initializes the cached frame constants.
  int samples_per_channel_ = -1;
  float inverse_samples_per_channel_ = -1.f;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_

View File

@ -0,0 +1,195 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/interpolated_gain_curve.h"
#include <algorithm>
#include <iterator>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
namespace webrtc {
// Out-of-line definitions of the static constexpr lookup tables; required
// to avoid ODR issues before C++17.
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
    InterpolatedGainCurve::approximation_params_x_;
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
    InterpolatedGainCurve::approximation_params_m_;
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
    InterpolatedGainCurve::approximation_params_q_;
// Constructs the curve; |histogram_name_prefix| selects the UMA histogram
// family used to report how long the level stays in each curve region.
InterpolatedGainCurve::InterpolatedGainCurve(ApmDataDumper* apm_data_dumper,
                                             std::string histogram_name_prefix)
    : region_logger_("WebRTC.Audio." + histogram_name_prefix +
                         ".FixedDigitalGainCurveRegion.Identity",
                     "WebRTC.Audio." + histogram_name_prefix +
                         ".FixedDigitalGainCurveRegion.Knee",
                     "WebRTC.Audio." + histogram_name_prefix +
                         ".FixedDigitalGainCurveRegion.Limiter",
                     "WebRTC.Audio." + histogram_name_prefix +
                         ".FixedDigitalGainCurveRegion.Saturation"),
      apm_data_dumper_(apm_data_dumper) {}
// On destruction, dumps the accumulated lookup counters (if any lookup
// happened) and logs the final region stay.
InterpolatedGainCurve::~InterpolatedGainCurve() {
  if (stats_.available) {
    RTC_DCHECK(apm_data_dumper_);
    apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_identity",
                              stats_.look_ups_identity_region);
    apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_knee",
                              stats_.look_ups_knee_region);
    apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_limiter",
                              stats_.look_ups_limiter_region);
    apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_saturation",
                              stats_.look_ups_saturation_region);
    region_logger_.LogRegionStats(stats_);
  }
}
// Creates one counts histogram per curve region (range 1-10000 s, 50
// buckets). Factory calls may return null when metrics are disabled.
InterpolatedGainCurve::RegionLogger::RegionLogger(
    std::string identity_histogram_name,
    std::string knee_histogram_name,
    std::string limiter_histogram_name,
    std::string saturation_histogram_name)
    : identity_histogram(
          metrics::HistogramFactoryGetCounts(identity_histogram_name,
                                             1,
                                             10000,
                                             50)),
      knee_histogram(metrics::HistogramFactoryGetCounts(knee_histogram_name,
                                                        1,
                                                        10000,
                                                        50)),
      limiter_histogram(
          metrics::HistogramFactoryGetCounts(limiter_histogram_name,
                                             1,
                                             10000,
                                             50)),
      saturation_histogram(
          metrics::HistogramFactoryGetCounts(saturation_histogram_name,
                                             1,
                                             10000,
                                             50)) {}
InterpolatedGainCurve::RegionLogger::~RegionLogger() = default;
// Reports how long (in seconds) the signal level stayed in the region
// recorded in |stats| to the matching UMA histogram.
void InterpolatedGainCurve::RegionLogger::LogRegionStats(
    const InterpolatedGainCurve::Stats& stats) const {
  using Region = InterpolatedGainCurve::GainCurveRegion;
  const int duration_s =
      stats.region_duration_frames / (1000 / kFrameDurationMs);
  // Select the histogram for the current region; a histogram pointer may be
  // null when metrics are disabled, in which case nothing is reported.
  metrics::Histogram* histogram = nullptr;
  switch (stats.region) {
    case Region::kIdentity:
      histogram = identity_histogram;
      break;
    case Region::kKnee:
      histogram = knee_histogram;
      break;
    case Region::kLimiter:
      histogram = limiter_histogram;
      break;
    case Region::kSaturation:
      histogram = saturation_histogram;
      break;
    default:
      RTC_NOTREACHED();
  }
  if (histogram) {
    metrics::HistogramAdd(histogram, duration_s);
  }
}
// Classifies |input_level| into a curve region, bumps that region's lookup
// counter and tracks for how many consecutive frames the level has stayed
// in the same region (logging the previous stay on a region change).
void InterpolatedGainCurve::UpdateStats(float input_level) const {
  stats_.available = true;
  GainCurveRegion region;
  if (input_level < approximation_params_x_[0]) {
    stats_.look_ups_identity_region++;
    region = GainCurveRegion::kIdentity;
  } else if (input_level <
             approximation_params_x_[kInterpolatedGainCurveKneePoints - 1]) {
    stats_.look_ups_knee_region++;
    region = GainCurveRegion::kKnee;
  } else if (input_level < kMaxInputLevelLinear) {
    stats_.look_ups_limiter_region++;
    region = GainCurveRegion::kLimiter;
  } else {
    stats_.look_ups_saturation_region++;
    region = GainCurveRegion::kSaturation;
  }
  if (region == stats_.region) {
    ++stats_.region_duration_frames;
  } else {
    // Region changed: log the completed stay and restart the frame count.
    region_logger_.LogRegionStats(stats_);
    stats_.region_duration_frames = 0;
    stats_.region = region;
  }
}
// Looks up a gain to apply given a non-negative input level.
// The cost of this operation depends on the region in which |input_level|
// falls.
// For the identity and the saturation regions the cost is O(1).
// For the other regions, namely knee and limiter, the cost is
// O(2 + log2(|LightkInterpolatedGainCurveTotalPoints|), plus O(1) for the
// linear interpolation (one product and one sum).
float InterpolatedGainCurve::LookUpGainToApply(float input_level) const {
UpdateStats(input_level);
if (input_level <= approximation_params_x_[0]) {
// Identity region.
return 1.0f;
}
if (input_level >= kMaxInputLevelLinear) {
// Saturating lower bound. The saturing samples exactly hit the clipping
// level. This method achieves has the lowest harmonic distorsion, but it
// may reduce the amplitude of the non-saturating samples too much.
return 32768.f / input_level;
}
// Knee and limiter regions; find the linear piece index. Spelling
// out the complete type was the only way to silence both the clang
// plugin and the windows compilers.
std::array<float, kInterpolatedGainCurveTotalPoints>::const_iterator it =
std::lower_bound(approximation_params_x_.begin(),
approximation_params_x_.end(), input_level);
const size_t index = std::distance(approximation_params_x_.begin(), it) - 1;
RTC_DCHECK_LE(0, index);
RTC_DCHECK_LT(index, approximation_params_m_.size());
RTC_DCHECK_LE(approximation_params_x_[index], input_level);
if (index < approximation_params_m_.size() - 1) {
RTC_DCHECK_LE(input_level, approximation_params_x_[index + 1]);
}
// Piece-wise linear interploation.
const float gain = approximation_params_m_[index] * input_level +
approximation_params_q_[index];
RTC_DCHECK_LE(0.f, gain);
return gain;
}
} // namespace webrtc

View File

@ -0,0 +1,152 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
#define MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
#include <array>
#include <string>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/constructor_magic.h"
#include "rtc_base/gtest_prod_util.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
class ApmDataDumper;
// Level at which int16 samples clip, in FloatS16 scale.
constexpr float kInputLevelScalingFactor = 32768.0f;
// Defined as DbfsToLinear(kLimiterMaxInputLevelDbFs)
constexpr float kMaxInputLevelLinear = static_cast<float>(36766.300710566735);
// Interpolated gain curve using under-approximation to avoid saturation.
//
// The goal of this class is allowing fast look ups to get an accurate
// estimate of the gain to apply given an estimated input level.
class InterpolatedGainCurve {
 public:
  enum class GainCurveRegion {
    kIdentity = 0,
    kKnee = 1,
    kLimiter = 2,
    kSaturation = 3
  };
  struct Stats {
    // Region in which the output level equals the input one.
    size_t look_ups_identity_region = 0;
    // Smoothing between the identity and the limiter regions.
    size_t look_ups_knee_region = 0;
    // Limiter region in which the output and input levels are linearly related.
    size_t look_ups_limiter_region = 0;
    // Region in which saturation may occur since the input level is beyond the
    // maximum expected by the limiter.
    size_t look_ups_saturation_region = 0;
    // True if stats have been populated.
    bool available = false;
    // The current region, and for how many frames the level has been
    // in that region.
    GainCurveRegion region = GainCurveRegion::kIdentity;
    int64_t region_duration_frames = 0;
  };
  InterpolatedGainCurve(ApmDataDumper* apm_data_dumper,
                        std::string histogram_name_prefix);
  ~InterpolatedGainCurve();
  Stats get_stats() const { return stats_; }
  // Given a non-negative input level (linear scale), a scalar factor to apply
  // to a sub-frame is returned.
  // Levels above kLimiterMaxInputLevelDbFs will be reduced to 0 dBFS
  // after applying this gain
  float LookUpGainToApply(float input_level) const;
 private:
  // For comparing 'approximation_params_*_' with ones computed by
  // ComputeInterpolatedGainCurve.
  FRIEND_TEST_ALL_PREFIXES(AutomaticGainController2InterpolatedGainCurve,
                           CheckApproximationParams);
  // Bundles the per-region UMA histograms and the logic to report them.
  struct RegionLogger {
    metrics::Histogram* identity_histogram;
    metrics::Histogram* knee_histogram;
    metrics::Histogram* limiter_histogram;
    metrics::Histogram* saturation_histogram;
    RegionLogger(std::string identity_histogram_name,
                 std::string knee_histogram_name,
                 std::string limiter_histogram_name,
                 std::string saturation_histogram_name);
    ~RegionLogger();
    void LogRegionStats(const InterpolatedGainCurve::Stats& stats) const;
  } region_logger_;
  void UpdateStats(float input_level) const;
  ApmDataDumper* const apm_data_dumper_;
  // Knots (input levels, FloatS16 scale) of the piece-wise linear
  // approximation of the gain curve.
  static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
      approximation_params_x_ = {
          {30057.296875, 30148.986328125, 30240.67578125, 30424.052734375,
           30607.4296875, 30790.806640625, 30974.18359375, 31157.560546875,
           31340.939453125, 31524.31640625, 31707.693359375, 31891.0703125,
           32074.447265625, 32257.82421875, 32441.201171875, 32624.580078125,
           32807.95703125, 32991.33203125, 33174.7109375, 33358.08984375,
           33541.46484375, 33724.84375, 33819.53515625, 34009.5390625,
           34200.05859375, 34389.81640625, 34674.48828125, 35054.375,
           35434.86328125, 35814.81640625, 36195.16796875, 36575.03125}};
  // Per-piece slopes (m) of the linear interpolation gain = m * x + q.
  static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
      approximation_params_m_ = {
          {-3.515235675877192989e-07, -1.050251626111275982e-06,
           -2.085213736791047268e-06, -3.443004743530764244e-06,
           -4.773849468620028347e-06, -6.077375928725814447e-06,
           -7.353257842623861507e-06, -8.601219633419532329e-06,
           -9.821013009059242904e-06, -1.101243378798244521e-05,
           -1.217532644659513608e-05, -1.330956911260727793e-05,
           -1.441507538402220234e-05, -1.549179251014720649e-05,
           -1.653970684856176376e-05, -1.755882840370759368e-05,
           -1.854918446042574942e-05, -1.951086778717581183e-05,
           -2.044398024736437947e-05, -2.1348627342376858e-05,
           -2.222496914328075945e-05, -2.265374678245279938e-05,
           -2.242570917587727308e-05, -2.220122041762806475e-05,
           -2.19802095671184361e-05, -2.176260204578284174e-05,
           -2.133731686626560986e-05, -2.092481918225530535e-05,
           -2.052459603874012828e-05, -2.013615448959171772e-05,
           -1.975903069251216948e-05, -1.939277899509761482e-05}};
  // Per-piece intercepts (q) of the linear interpolation gain = m * x + q.
  static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
      approximation_params_q_ = {
          {1.010565876960754395, 1.031631827354431152, 1.062929749488830566,
           1.104239225387573242, 1.144973039627075195, 1.185109615325927734,
           1.224629044532775879, 1.263512492179870605, 1.301741957664489746,
           1.339300632476806641, 1.376173257827758789, 1.412345528602600098,
           1.447803974151611328, 1.482536554336547852, 1.516532182693481445,
           1.549780607223510742, 1.582272171974182129, 1.613999366760253906,
           1.644955039024353027, 1.675132393836975098, 1.704526185989379883,
           1.718986630439758301, 1.711274504661560059, 1.703639745712280273,
           1.696081161499023438, 1.688597679138183594, 1.673851132392883301,
           1.659391283988952637, 1.645209431648254395, 1.631297469139099121,
           1.617647409439086914, 1.604251742362976074}};
  // Stats.
  mutable Stats stats_;
  RTC_DISALLOW_COPY_AND_ASSIGN(InterpolatedGainCurve);
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_

View File

@ -0,0 +1,150 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/limiter.h"
#include <algorithm>
#include <array>
#include <cmath>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
// This constant affects the way scaling factors are interpolated for the first
// sub-frame of a frame. Only in the case in which the first sub-frame has an
// estimated level which is greater than that of the previous analyzed
// sub-frame, linear interpolation is replaced with a power function which
// reduces the chances of over-shooting (and hence saturation), however reducing
// the fixed gain effectiveness.
constexpr float kAttackFirstSubframeInterpolationPower = 8.f;
// Fills |subframe| with factors interpolated from |last_factor| (sample 0)
// towards |current_factor|, using a power-shaped curve to reduce
// over-shooting during attacks.
void InterpolateFirstSubframe(float last_factor,
                              float current_factor,
                              rtc::ArrayView<float> subframe) {
  const auto n = subframe.size();
  constexpr auto p = kAttackFirstSubframeInterpolationPower;
  for (size_t i = 0; i < n; ++i) {
    // Cast |i| to float before dividing: the integer division i / n is 0 for
    // every i < n, which would make each sample equal to |last_factor|
    // instead of interpolating towards |current_factor|.
    subframe[i] =
        std::pow(1.f - static_cast<float>(i) / n, p) *
            (last_factor - current_factor) +
        current_factor;
  }
}
// Expands |scaling_factors| (one factor per sub-frame boundary) into
// |per_sample_scaling_factors| via linear interpolation within each
// sub-frame; the first sub-frame uses a power-shaped interpolation when the
// gain decreases (attack) to reduce over-shooting.
void ComputePerSampleSubframeFactors(
    const std::array<float, kSubFramesInFrame + 1>& scaling_factors,
    size_t samples_per_channel,
    rtc::ArrayView<float> per_sample_scaling_factors) {
  const size_t num_subframes = scaling_factors.size() - 1;
  const size_t subframe_size =
      rtc::CheckedDivExact(samples_per_channel, num_subframes);
  // Handle first sub-frame differently in case of attack.
  const bool is_attack = scaling_factors[0] > scaling_factors[1];
  if (is_attack) {
    InterpolateFirstSubframe(
        scaling_factors[0], scaling_factors[1],
        rtc::ArrayView<float>(
            per_sample_scaling_factors.subview(0, subframe_size)));
  }
  // Linear interpolation for the remaining (or all) sub-frames.
  for (size_t i = is_attack ? 1 : 0; i < num_subframes; ++i) {
    const size_t subframe_start = i * subframe_size;
    const float scaling_start = scaling_factors[i];
    const float scaling_end = scaling_factors[i + 1];
    const float scaling_diff = (scaling_end - scaling_start) / subframe_size;
    for (size_t j = 0; j < subframe_size; ++j) {
      per_sample_scaling_factors[subframe_start + j] =
          scaling_start + scaling_diff * j;
    }
  }
}
// Multiplies each sample of |signal| by its per-sample factor and clamps
// the result to the FloatS16 range.
void ScaleSamples(rtc::ArrayView<const float> per_sample_scaling_factors,
                  AudioFrameView<float> signal) {
  const size_t samples_per_channel = signal.samples_per_channel();
  RTC_DCHECK_EQ(samples_per_channel, per_sample_scaling_factors.size());
  for (size_t channel_idx = 0; channel_idx < signal.num_channels();
       ++channel_idx) {
    auto channel = signal.channel(channel_idx);
    for (size_t sample_idx = 0; sample_idx < samples_per_channel;
         ++sample_idx) {
      const float scaled =
          channel[sample_idx] * per_sample_scaling_factors[sample_idx];
      channel[sample_idx] =
          rtc::SafeClamp(scaled, kMinFloatS16Value, kMaxFloatS16Value);
    }
  }
}
// DCHECKs that a frame at |sample_rate_hz| fits in the fixed-size
// per-sample scaling factors work array.
void CheckLimiterSampleRate(size_t sample_rate_hz) {
  // Check that per_sample_scaling_factors_ is large enough.
  RTC_DCHECK_LE(sample_rate_hz,
                kMaximalNumberOfSamplesPerChannel * 1000 / kFrameDurationMs);
}
} // namespace
// Constructs the limiter; |sample_rate_hz| must satisfy
// CheckLimiterSampleRate(). |histogram_name| prefixes the UMA histograms
// reported by the interpolated gain curve.
Limiter::Limiter(size_t sample_rate_hz,
                 ApmDataDumper* apm_data_dumper,
                 std::string histogram_name)
    : interp_gain_curve_(apm_data_dumper, histogram_name),
      level_estimator_(sample_rate_hz, apm_data_dumper),
      apm_data_dumper_(apm_data_dumper) {
  CheckLimiterSampleRate(sample_rate_hz);
}
Limiter::~Limiter() = default;
// Estimates per-sub-frame levels, maps them to gains through the
// interpolated gain curve, expands the gains to per-sample factors and
// applies them to |signal| (with hard-clipping).
void Limiter::Process(AudioFrameView<float> signal) {
  const auto level_estimate = level_estimator_.ComputeLevel(signal);
  RTC_DCHECK_EQ(level_estimate.size() + 1, scaling_factors_.size());
  // The first factor is the last one of the previous frame, so that the
  // interpolation is continuous across frame boundaries.
  scaling_factors_[0] = last_scaling_factor_;
  std::transform(level_estimate.begin(), level_estimate.end(),
                 scaling_factors_.begin() + 1, [this](float x) {
                   return interp_gain_curve_.LookUpGainToApply(x);
                 });
  const size_t samples_per_channel = signal.samples_per_channel();
  RTC_DCHECK_LE(samples_per_channel, kMaximalNumberOfSamplesPerChannel);
  auto per_sample_scaling_factors = rtc::ArrayView<float>(
      &per_sample_scaling_factors_[0], samples_per_channel);
  ComputePerSampleSubframeFactors(scaling_factors_, samples_per_channel,
                                  per_sample_scaling_factors);
  ScaleSamples(per_sample_scaling_factors, signal);
  last_scaling_factor_ = scaling_factors_.back();
  // Dump data for debug.
  apm_data_dumper_->DumpRaw("agc2_gain_curve_applier_scaling_factors",
                            samples_per_channel,
                            per_sample_scaling_factors_.data());
}
// Returns the gain curve lookup statistics accumulated so far.
InterpolatedGainCurve::Stats Limiter::GetGainCurveStats() const {
  return interp_gain_curve_.get_stats();
}
// Updates the sample rate after validating it; see the header for the
// supported rates.
void Limiter::SetSampleRate(size_t sample_rate_hz) {
  CheckLimiterSampleRate(sample_rate_hz);
  level_estimator_.SetSampleRate(sample_rate_hz);
}
// Resets the level estimator state.
void Limiter::Reset() {
  level_estimator_.Reset();
}
// Returns the most recent audio level from the level estimator.
float Limiter::LastAudioLevel() const {
  return level_estimator_.LastAudioLevel();
}
} // namespace webrtc

View File

@ -0,0 +1,64 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
#include <string>
#include <vector>
#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
#include "modules/audio_processing/agc2/interpolated_gain_curve.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/constructor_magic.h"
namespace webrtc {
class ApmDataDumper;
// Limits the audio level to avoid saturation: estimates per-sub-frame
// levels, maps them through an interpolated gain curve and applies the
// resulting per-sample gains, hard-clipping the output to FloatS16.
class Limiter {
 public:
  Limiter(size_t sample_rate_hz,
          ApmDataDumper* apm_data_dumper,
          std::string histogram_name_prefix);
  Limiter(const Limiter& limiter) = delete;
  Limiter& operator=(const Limiter& limiter) = delete;
  ~Limiter();
  // Applies limiter and hard-clipping to |signal|.
  void Process(AudioFrameView<float> signal);
  // Returns the gain curve lookup statistics accumulated so far.
  InterpolatedGainCurve::Stats GetGainCurveStats() const;
  // Supported rates must be
  // * supported by FixedDigitalLevelEstimator
  // * below kMaximalNumberOfSamplesPerChannel*1000/kFrameDurationMs
  //   so that samples_per_channel fit in the
  //   per_sample_scaling_factors_ array.
  void SetSampleRate(size_t sample_rate_hz);
  // Resets the internal state.
  void Reset();
  // Returns the most recent audio level from the level estimator.
  float LastAudioLevel() const;
 private:
  const InterpolatedGainCurve interp_gain_curve_;
  FixedDigitalLevelEstimator level_estimator_;
  ApmDataDumper* const apm_data_dumper_ = nullptr;
  // Work array containing the sub-frame scaling factors to be interpolated.
  std::array<float, kSubFramesInFrame + 1> scaling_factors_ = {};
  std::array<float, kMaximalNumberOfSamplesPerChannel>
      per_sample_scaling_factors_ = {};
  // Last factor of the previous frame; seeds the interpolation of the next
  // frame so gains are continuous across frame boundaries.
  float last_scaling_factor_ = 1.f;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_

View File

@ -0,0 +1,138 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/limiter_db_gain_curve.h"
#include <cmath>
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
// Returns the input level (dBFS) at which the knee region begins. The
// parameters must place the knee strictly below the maximal input level.
double ComputeKneeStart(double max_input_level_db,
                        double knee_smoothness_db,
                        double compression_ratio) {
  RTC_CHECK_LT((compression_ratio - 1.0) * knee_smoothness_db /
                   (2.0 * compression_ratio),
               max_input_level_db);
  const double half_knee = knee_smoothness_db / 2.0;
  const double compression_offset =
      max_input_level_db / (compression_ratio - 1.0);
  return -half_knee - compression_offset;
}
// Returns the coefficients {a, b, c} of the quadratic a*x^2 + b*x + c that
// smoothly connects the identity and compression regions over the knee.
std::array<double, 3> ComputeKneeRegionPolynomial(double knee_start_dbfs,
                                                  double knee_smoothness_db,
                                                  double compression_ratio) {
  const double quadratic = (1.0 - compression_ratio) /
                           (2.0 * knee_smoothness_db * compression_ratio);
  const double linear = 1.0 - 2.0 * quadratic * knee_start_dbfs;
  const double constant = quadratic * knee_start_dbfs * knee_start_dbfs;
  return {{quadratic, linear, constant}};
}
// Computes the scale factor d1 used by the limiter-region gain derivative.
double ComputeLimiterD1(double max_input_level_db, double compression_ratio) {
  return (std::pow(10.0, -max_input_level_db / (20.0 * compression_ratio)) *
          (1.0 - compression_ratio) / compression_ratio) /
         kMaxAbsFloatS16Value;
}
// Computes the exponent d2 used by the limiter-region gain derivative.
constexpr double ComputeLimiterD2(double compression_ratio) {
  const double ratio = compression_ratio;
  return (1.0 - 2.0 * ratio) / ratio;
}
// Computes the scale factor i2 of the limiter-region gain antiderivative;
// |gain_curve_limiter_i1| must be non-zero.
double ComputeLimiterI2(double max_input_level_db,
                        double compression_ratio,
                        double gain_curve_limiter_i1) {
  RTC_CHECK_NE(gain_curve_limiter_i1, 0.f);
  return std::pow(10.0, -max_input_level_db / (20.0 * compression_ratio)) /
         gain_curve_limiter_i1 /
         std::pow(kMaxAbsFloatS16Value, gain_curve_limiter_i1 - 1);
}
} // namespace
// Precomputes every curve parameter (region boundaries, knee polynomial and
// limiter derivative/integral factors) from the hard-coded configuration
// constants: max input level, knee smoothness and compression ratio.
LimiterDbGainCurve::LimiterDbGainCurve()
    : max_input_level_linear_(DbfsToFloatS16(max_input_level_db_)),
      knee_start_dbfs_(ComputeKneeStart(max_input_level_db_,
                                        knee_smoothness_db_,
                                        compression_ratio_)),
      knee_start_linear_(DbfsToFloatS16(knee_start_dbfs_)),
      limiter_start_dbfs_(knee_start_dbfs_ + knee_smoothness_db_),
      limiter_start_linear_(DbfsToFloatS16(limiter_start_dbfs_)),
      knee_region_polynomial_(ComputeKneeRegionPolynomial(knee_start_dbfs_,
                                                          knee_smoothness_db_,
                                                          compression_ratio_)),
      gain_curve_limiter_d1_(
          ComputeLimiterD1(max_input_level_db_, compression_ratio_)),
      gain_curve_limiter_d2_(ComputeLimiterD2(compression_ratio_)),
      gain_curve_limiter_i1_(1.0 / compression_ratio_),
      gain_curve_limiter_i2_(ComputeLimiterI2(max_input_level_db_,
                                              compression_ratio_,
                                              gain_curve_limiter_i1_)) {
  static_assert(knee_smoothness_db_ > 0.0f, "");
  static_assert(compression_ratio_ > 1.0f, "");
  RTC_CHECK_GE(max_input_level_db_, knee_start_dbfs_ + knee_smoothness_db_);
}
// Out-of-line definitions of the static constexpr configuration constants.
constexpr double LimiterDbGainCurve::max_input_level_db_;
constexpr double LimiterDbGainCurve::knee_smoothness_db_;
constexpr double LimiterDbGainCurve::compression_ratio_;
// Maps an input level (dBFS) to the output level (dBFS) of the gain curve.
double LimiterDbGainCurve::GetOutputLevelDbfs(double input_level_dbfs) const {
  // Identity region: below the knee the level passes through unchanged.
  if (input_level_dbfs < knee_start_dbfs_) {
    return input_level_dbfs;
  }
  // Knee region: smooth quadratic transition.
  if (input_level_dbfs < limiter_start_dbfs_) {
    return GetKneeRegionOutputLevelDbfs(input_level_dbfs);
  }
  // Compression region.
  return GetCompressorRegionOutputLevelDbfs(input_level_dbfs);
}
// Returns the linear gain for |input_level_linear| (FloatS16 scale): 1 in
// the identity region, otherwise output/input computed via the dB-domain
// curve.
double LimiterDbGainCurve::GetGainLinear(double input_level_linear) const {
  if (input_level_linear < knee_start_linear_) {
    return 1.0;
  }
  return DbfsToFloatS16(
             GetOutputLevelDbfs(FloatS16ToDbfs(input_level_linear))) /
         input_level_linear;
}
// Computes the first derivative of GetGainLinear() in |x|.
// Only valid in the beyond-knee (limiter) region.
double LimiterDbGainCurve::GetGainFirstDerivativeLinear(double x) const {
  // Beyond-knee region only.
  RTC_CHECK_GE(x, limiter_start_linear_ - 1e-7 * kMaxAbsFloatS16Value);
  return gain_curve_limiter_d1_ *
         std::pow(x / kMaxAbsFloatS16Value, gain_curve_limiter_d2_);
}
// Computes the integral of GetGainLinear() over the range [x0, x1].
// Only valid in the beyond-knee (limiter) region.
double LimiterDbGainCurve::GetGainIntegralLinear(double x0, double x1) const {
  RTC_CHECK_LE(x0, x1);                     // Valid interval.
  RTC_CHECK_GE(x0, limiter_start_linear_);  // Beyond-knee region only.
  // Closed-form antiderivative of the beyond-knee gain curve.
  auto antiderivative = [this](double x) {
    return gain_curve_limiter_i2_ * std::pow(x, gain_curve_limiter_i1_);
  };
  return antiderivative(x1) - antiderivative(x0);
}
// Evaluates the quadratic knee polynomial a*x^2 + b*x + c at the given input
// level (dBFS).
double LimiterDbGainCurve::GetKneeRegionOutputLevelDbfs(
    double input_level_dbfs) const {
  const double x = input_level_dbfs;
  return knee_region_polynomial_[0] * x * x + knee_region_polynomial_[1] * x +
         knee_region_polynomial_[2];
}
// Linear compression towards the maximum input level with slope
// 1 / compression_ratio_.
double LimiterDbGainCurve::GetCompressorRegionOutputLevelDbfs(
    double input_level_dbfs) const {
  const double headroom_db = input_level_dbfs - max_input_level_db_;
  return headroom_db / compression_ratio_;
}
} // namespace webrtc

View File

@ -0,0 +1,76 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_
#define MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_
#include <array>
#include "modules/audio_processing/agc2/agc2_testing_common.h"
namespace webrtc {
// A class for computing a limiter gain curve (in dB scale) given a set of
// hard-coded parameters (namely, test::kLimiterMaxInputLevelDbFs,
// test::kLimiterKneeSmoothnessDb and test::kLimiterCompressionRatio). The
// generated curve consists of four regions: identity (linear), knee (quadratic
// polynomial), compression (linear), saturation (linear). The aforementioned
// constants are used to shape the different regions.
class LimiterDbGainCurve {
 public:
  LimiterDbGainCurve();
  // Accessors for the curve parameters (dB scale and linear S16 scale).
  double max_input_level_db() const { return max_input_level_db_; }
  double max_input_level_linear() const { return max_input_level_linear_; }
  double knee_start_linear() const { return knee_start_linear_; }
  double limiter_start_linear() const { return limiter_start_linear_; }
  // These methods can be marked 'constexpr' in C++ 14.
  // Maps an input level to the corresponding output level (both in dBFS).
  double GetOutputLevelDbfs(double input_level_dbfs) const;
  // Returns the linear gain factor for a linear input level.
  double GetGainLinear(double input_level_linear) const;
  // First derivative of GetGainLinear() in |x| (beyond-knee region only).
  double GetGainFirstDerivativeLinear(double x) const;
  // Integral of GetGainLinear() over [x0, x1] (beyond-knee region only).
  double GetGainIntegralLinear(double x0, double x1) const;
 private:
  double GetKneeRegionOutputLevelDbfs(double input_level_dbfs) const;
  double GetCompressorRegionOutputLevelDbfs(double input_level_dbfs) const;
  static constexpr double max_input_level_db_ = test::kLimiterMaxInputLevelDbFs;
  static constexpr double knee_smoothness_db_ = test::kLimiterKneeSmoothnessDb;
  static constexpr double compression_ratio_ = test::kLimiterCompressionRatio;
  const double max_input_level_linear_;
  // Do not modify signal with level <= knee_start_dbfs_.
  const double knee_start_dbfs_;
  const double knee_start_linear_;
  // The upper end of the knee region, which is between knee_start_dbfs_ and
  // limiter_start_dbfs_.
  const double limiter_start_dbfs_;
  const double limiter_start_linear_;
  // Coefficients {a, b, c} of the knee region polynomial
  // ax^2 + bx + c in the DB scale.
  const std::array<double, 3> knee_region_polynomial_;
  // Parameters for the computation of the first derivative of GetGainLinear().
  const double gain_curve_limiter_d1_;
  const double gain_curve_limiter_d2_;
  // Parameters for the computation of the integral of GetGainLinear().
  const double gain_curve_limiter_i1_;
  const double gain_curve_limiter_i2_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_

View File

@ -0,0 +1,114 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/noise_level_estimator.h"
#include <stddef.h>
#include <algorithm>
#include <cmath>
#include <numeric>
#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr int kFramesPerSecond = 100;
// Returns the energy of the most energetic channel in |audio| (sum of squared
// samples, linear S16 scale).
float FrameEnergy(const AudioFrameView<const float>& audio) {
  float max_energy = 0.f;
  for (size_t ch = 0; ch < audio.num_channels(); ++ch) {
    float channel_energy = 0.f;
    for (const float sample : audio.channel(ch)) {
      channel_energy += sample * sample;
    }
    max_energy = std::max(max_energy, channel_energy);
  }
  return max_energy;
}
// Converts |signal_energy|, the energy of |num_samples| samples, to the RMS
// level in dBFS.
float EnergyToDbfs(float signal_energy, size_t num_samples) {
  return FloatS16ToDbfs(std::sqrt(signal_energy / num_samples));
}
} // namespace
// Initializes the estimator assuming 48 kHz audio; Analyze() re-initializes
// on the fly when frames at a different sample rate are observed.
NoiseLevelEstimator::NoiseLevelEstimator(ApmDataDumper* data_dumper)
    : signal_classifier_(data_dumper) {
  Initialize(48000);
}
NoiseLevelEstimator::~NoiseLevelEstimator() {}
// Resets the estimator state for the given sample rate.
void NoiseLevelEstimator::Initialize(int sample_rate_hz) {
  sample_rate_hz_ = sample_rate_hz;
  noise_energy_ = 1.f;
  first_update_ = true;
  // Lower bound for the noise energy estimate: scales with the frame size
  // (sample_rate / kFramesPerSecond samples per frame), i.e. an RMS of 2 in
  // S16 units.
  min_noise_energy_ = sample_rate_hz * 2.f * 2.f / kFramesPerSecond;
  noise_energy_hold_counter_ = 0;
  signal_classifier_.Initialize(sample_rate_hz);
}
// Updates the minimum-statistics-style noise energy estimate with |frame| and
// returns the estimated noise level in dBFS.
float NoiseLevelEstimator::Analyze(const AudioFrameView<const float>& frame) {
  // Infer the sample rate from the frame size; re-initialize if it changed.
  const int rate =
      static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
  if (rate != sample_rate_hz_) {
    Initialize(rate);
  }
  const float frame_energy = FrameEnergy(frame);
  if (frame_energy <= 0.f) {
    // Zero-energy frame: keep the previous estimate. A negative energy would
    // be a programming error.
    RTC_DCHECK_GE(frame_energy, 0.f);
    return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
  }
  if (first_update_) {
    // Initialize the noise energy to the frame energy.
    first_update_ = false;
    return EnergyToDbfs(
        noise_energy_ = std::max(frame_energy, min_noise_energy_),
        frame.samples_per_channel());
  }
  // Classification is based on the first channel only.
  const SignalClassifier::SignalType signal_type =
      signal_classifier_.Analyze(frame.channel(0));
  // Update the noise estimate in a minimum statistics-type manner.
  if (signal_type == SignalClassifier::SignalType::kStationary) {
    if (frame_energy > noise_energy_) {
      // Leak the estimate upwards towards the frame energy if no recent
      // downward update.
      noise_energy_hold_counter_ = std::max(noise_energy_hold_counter_ - 1, 0);
      if (noise_energy_hold_counter_ == 0) {
        // Upward update limited to +1% per frame.
        noise_energy_ = std::min(noise_energy_ * 1.01f, frame_energy);
      }
    } else {
      // Update smoothly downwards with a limited maximum update magnitude.
      noise_energy_ =
          std::max(noise_energy_ * 0.9f,
                   noise_energy_ + 0.05f * (frame_energy - noise_energy_));
      // Freeze upward updates for the next 1000 frames.
      noise_energy_hold_counter_ = 1000;
    }
  } else {
    // For a non-stationary signal, leak the estimate downwards in order to
    // avoid estimate locking due to incorrect signal classification.
    noise_energy_ = noise_energy_ * 0.99f;
  }
  // Ensure a minimum of the estimate.
  return EnergyToDbfs(
      noise_energy_ = std::max(noise_energy_, min_noise_energy_),
      frame.samples_per_channel());
}
} // namespace webrtc

View File

@ -0,0 +1,43 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
#include "modules/audio_processing/agc2/signal_classifier.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/constructor_magic.h"
namespace webrtc {
class ApmDataDumper;
// Noise level estimator based on signal classification and a
// minimum-statistics-style energy update.
class NoiseLevelEstimator {
 public:
  NoiseLevelEstimator(ApmDataDumper* data_dumper);
  ~NoiseLevelEstimator();
  // Returns the estimated noise level in dBFS.
  float Analyze(const AudioFrameView<const float>& frame);
 private:
  // Resets the state for the given sample rate.
  void Initialize(int sample_rate_hz);
  int sample_rate_hz_;
  // Lower bound for |noise_energy_|.
  float min_noise_energy_;
  bool first_update_;
  // Current noise energy estimate (linear energy, not dB).
  float noise_energy_;
  // Number of frames before upward estimate updates are allowed again.
  int noise_energy_hold_counter_;
  SignalClassifier signal_classifier_;
  RTC_DISALLOW_COPY_AND_ASSIGN(NoiseLevelEstimator);
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_

View File

@ -0,0 +1,70 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/noise_spectrum_estimator.h"
#include <string.h>
#include <algorithm>
#include "api/array_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/arraysize.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr float kMinNoisePower = 100.f;
} // namespace
// Constructs the estimator and resets the noise spectrum to the minimum noise
// power. NOTE(review): |data_dumper| is stored as a raw pointer — presumably
// not owned; confirm against callers.
NoiseSpectrumEstimator::NoiseSpectrumEstimator(ApmDataDumper* data_dumper)
    : data_dumper_(data_dumper) {
  Initialize();
}
// Resets every bin of the noise spectral estimate to the minimum noise power.
void NoiseSpectrumEstimator::Initialize() {
  for (auto& bin : noise_spectrum_) {
    bin = kMinNoisePower;
  }
}
// Updates the 65-bin noise spectral estimate towards the signal |spectrum|.
void NoiseSpectrumEstimator::Update(rtc::ArrayView<const float> spectrum,
                                    bool first_update) {
  RTC_DCHECK_EQ(65, spectrum.size());
  if (first_update) {
    // Initialize the noise spectral estimate with the signal spectrum.
    std::copy(spectrum.data(), spectrum.data() + spectrum.size(),
              noise_spectrum_);
  } else {
    // Smoothly update the noise spectral estimate towards the signal spectrum
    // such that the magnitude of the updates are limited.
    for (size_t k = 0; k < spectrum.size(); ++k) {
      if (noise_spectrum_[k] < spectrum[k]) {
        // Upward update capped at +1% per call.
        noise_spectrum_[k] = std::min(
            1.01f * noise_spectrum_[k],
            noise_spectrum_[k] + 0.05f * (spectrum[k] - noise_spectrum_[k]));
      } else {
        // Downward update capped at -1% per call.
        noise_spectrum_[k] = std::max(
            0.99f * noise_spectrum_[k],
            noise_spectrum_[k] + 0.05f * (spectrum[k] - noise_spectrum_[k]));
      }
    }
  }
  // Ensure that the noise spectral estimate does not become too low.
  for (auto& v : noise_spectrum_) {
    v = std::max(v, kMinNoisePower);
  }
  data_dumper_->DumpRaw("lc_noise_spectrum", 65, noise_spectrum_);
  data_dumper_->DumpRaw("lc_signal_spectrum", spectrum);
}
} // namespace webrtc

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_
#include "api/array_view.h"
namespace webrtc {
class ApmDataDumper;
// Estimates the noise spectrum (65 bins) from observed signal spectra using
// rate-limited smoothing.
class NoiseSpectrumEstimator {
 public:
  explicit NoiseSpectrumEstimator(ApmDataDumper* data_dumper);
  NoiseSpectrumEstimator() = delete;
  NoiseSpectrumEstimator(const NoiseSpectrumEstimator&) = delete;
  NoiseSpectrumEstimator& operator=(const NoiseSpectrumEstimator&) = delete;
  // Resets the estimate to the minimum noise power.
  void Initialize();
  // Updates the estimate with a new 65-bin signal |spectrum|.
  void Update(rtc::ArrayView<const float> spectrum, bool first_update);
  // Returns a non-owning view of the current noise spectral estimate.
  rtc::ArrayView<const float> GetNoiseSpectrum() const {
    return rtc::ArrayView<const float>(noise_spectrum_);
  }
 private:
  ApmDataDumper* data_dumper_;
  float noise_spectrum_[65];
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_

View File

@ -0,0 +1,233 @@
# Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
import("../../../../webrtc.gni")
# Main RNN-based VAD target: feature extraction and the RNN itself.
rtc_library("rnn_vad") {
  visibility = [ "../*" ]
  sources = [
    "features_extraction.cc",
    "features_extraction.h",
    "rnn.cc",
    "rnn.h",
  ]
  # NEON cflags are needed when NEON is built on 32-bit ARM.
  if (rtc_build_with_neon && current_cpu != "arm64") {
    suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
    cflags = [ "-mfpu=neon" ]
  }
  deps = [
    ":rnn_vad_common",
    ":rnn_vad_lp_residual",
    ":rnn_vad_pitch",
    ":rnn_vad_sequence_buffer",
    ":rnn_vad_spectral_features",
    "..:biquad_filter",
    "../../../../api:array_view",
    "../../../../api:function_view",
    "../../../../rtc_base:checks",
    "../../../../rtc_base:logging",
    "../../../../rtc_base/system:arch",
    "//third_party/rnnoise:rnn_vad",
  ]
}
rtc_library("rnn_vad_auto_correlation") {
  sources = [
    "auto_correlation.cc",
    "auto_correlation.h",
  ]
  deps = [
    ":rnn_vad_common",
    "../../../../api:array_view",
    "../../../../rtc_base:checks",
    "../../utility:pffft_wrapper",
  ]
}
rtc_library("rnn_vad_common") {
  # TODO(alessiob): Make this target visibility private.
  visibility = [
    ":*",
    "..:rnn_vad_with_level",
  ]
  sources = [
    "common.cc",
    "common.h",
  ]
  deps = [
    "../../../../rtc_base/system:arch",
    "../../../../system_wrappers",
  ]
}
rtc_library("rnn_vad_lp_residual") {
  sources = [
    "lp_residual.cc",
    "lp_residual.h",
  ]
  deps = [
    "../../../../api:array_view",
    "../../../../rtc_base:checks",
  ]
}
rtc_library("rnn_vad_pitch") {
  sources = [
    "pitch_info.h",
    "pitch_search.cc",
    "pitch_search.h",
    "pitch_search_internal.cc",
    "pitch_search_internal.h",
  ]
  deps = [
    ":rnn_vad_auto_correlation",
    ":rnn_vad_common",
    "../../../../api:array_view",
    "../../../../rtc_base:checks",
  ]
}
rtc_source_set("rnn_vad_ring_buffer") {
  sources = [ "ring_buffer.h" ]
  deps = [
    "../../../../api:array_view",
    "../../../../rtc_base:checks",
  ]
}
rtc_source_set("rnn_vad_sequence_buffer") {
  sources = [ "sequence_buffer.h" ]
  deps = [
    "../../../../api:array_view",
    "../../../../rtc_base:checks",
  ]
}
rtc_library("rnn_vad_spectral_features") {
  sources = [
    "spectral_features.cc",
    "spectral_features.h",
    "spectral_features_internal.cc",
    "spectral_features_internal.h",
  ]
  deps = [
    ":rnn_vad_common",
    ":rnn_vad_ring_buffer",
    ":rnn_vad_symmetric_matrix_buffer",
    "../../../../api:array_view",
    "../../../../rtc_base:checks",
    "../../utility:pffft_wrapper",
  ]
}
rtc_source_set("rnn_vad_symmetric_matrix_buffer") {
  sources = [ "symmetric_matrix_buffer.h" ]
  deps = [
    "../../../../api:array_view",
    "../../../../rtc_base:checks",
  ]
}
# Test-only targets: helpers, resources and unit tests.
if (rtc_include_tests) {
  rtc_library("test_utils") {
    testonly = true
    sources = [
      "test_utils.cc",
      "test_utils.h",
    ]
    deps = [
      ":rnn_vad",
      ":rnn_vad_common",
      "../../../../api:array_view",
      "../../../../api:scoped_refptr",
      "../../../../rtc_base:checks",
      "../../../../rtc_base/system:arch",
      "../../../../system_wrappers",
      "../../../../test:fileutils",
      "../../../../test:test_support",
    ]
  }
  unittest_resources = [
    "../../../../resources/audio_processing/agc2/rnn_vad/band_energies.dat",
    "../../../../resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat",
    "../../../../resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat",
    "../../../../resources/audio_processing/agc2/rnn_vad/pitch_search_int.dat",
    "../../../../resources/audio_processing/agc2/rnn_vad/samples.pcm",
    "../../../../resources/audio_processing/agc2/rnn_vad/vad_prob.dat",
  ]
  # iOS bundles test resources instead of reading them from the filesystem.
  if (is_ios) {
    bundle_data("unittests_bundle_data") {
      testonly = true
      sources = unittest_resources
      outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ]
    }
  }
  rtc_library("unittests") {
    testonly = true
    sources = [
      "auto_correlation_unittest.cc",
      "features_extraction_unittest.cc",
      "lp_residual_unittest.cc",
      "pitch_search_internal_unittest.cc",
      "pitch_search_unittest.cc",
      "ring_buffer_unittest.cc",
      "rnn_unittest.cc",
      "rnn_vad_unittest.cc",
      "sequence_buffer_unittest.cc",
      "spectral_features_internal_unittest.cc",
      "spectral_features_unittest.cc",
      "symmetric_matrix_buffer_unittest.cc",
    ]
    deps = [
      ":rnn_vad",
      ":rnn_vad_auto_correlation",
      ":rnn_vad_common",
      ":rnn_vad_lp_residual",
      ":rnn_vad_pitch",
      ":rnn_vad_ring_buffer",
      ":rnn_vad_sequence_buffer",
      ":rnn_vad_spectral_features",
      ":rnn_vad_symmetric_matrix_buffer",
      ":test_utils",
      "../..:audioproc_test_utils",
      "../../../../api:array_view",
      "../../../../common_audio/",
      "../../../../rtc_base:checks",
      "../../../../rtc_base:logging",
      "../../../../rtc_base/system:arch",
      "../../../../test:test_support",
      "../../utility:pffft_wrapper",
      "//third_party/rnnoise:rnn_vad",
    ]
    absl_deps = [ "//third_party/abseil-cpp/absl/memory" ]
    data = unittest_resources
    if (is_ios) {
      deps += [ ":unittests_bundle_data" ]
    }
  }
  # Command-line tool to run the VAD over a PCM file.
  rtc_executable("rnn_vad_tool") {
    testonly = true
    sources = [ "rnn_vad_tool.cc" ]
    deps = [
      ":rnn_vad",
      ":rnn_vad_common",
      "../../../../api:array_view",
      "../../../../common_audio",
      "../../../../rtc_base:rtc_base_approved",
      "../../../../test:test_support",
      "//third_party/abseil-cpp/absl/flags:flag",
      "//third_party/abseil-cpp/absl/flags:parse",
    ]
  }
}

View File

@ -0,0 +1,92 @@
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
#include <algorithm>
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace {
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
static_assert(1 << kAutoCorrelationFftOrder >
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
"");
} // namespace
// Allocates the real FFT engine and three work buffers: |tmp_| for time-domain
// scratch, |X_| for the sliding-chunk spectrum and |H_| for the reversed
// reference frame spectrum.
AutoCorrelationCalculator::AutoCorrelationCalculator()
    : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
      tmp_(fft_.CreateBuffer()),
      X_(fft_.CreateBuffer()),
      H_(fft_.CreateBuffer()) {}
AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
// The auto-correlation coefficients are computed as follows:
//   |.........|...........|  <- pitch buffer
//             [ x (fixed) ]
//   [   y_0   ]
//          [ y_{m-1} ]
// x and y are sub-arrays of equal length; x is never moved, whereas y slides.
// The cross-correlation between y_0 and x corresponds to the auto-correlation
// for the maximum pitch period. Hence, the first value in |auto_corr| has an
// inverted lag equal to 0 that corresponds to a lag equal to the maximum
// pitch period. The computation is done via FFT-based convolution.
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
    rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
    rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
  RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
  RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
  constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder;
  constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
  static_assert(kConvolutionLength == kFrameSize20ms12kHz,
                "Mismatch between pitch buffer size, frame size and maximum "
                "pitch period.");
  static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength,
                "The FFT length is not sufficiently big to avoid cyclic "
                "convolution errors.");
  auto tmp = tmp_->GetView();
  // Compute the FFT for the reversed reference frame - i.e.,
  // pitch_buf[-kConvolutionLength:] - zero-padded to the FFT length.
  std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(),
                    tmp.begin());
  std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f);
  fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false);
  // Compute the FFT for the sliding frames chunk. The sliding frames are
  // defined as pitch_buf[i:i+kConvolutionLength] where i in
  // [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is
  // defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength].
  std::copy(pitch_buf.begin(),
            pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz,
            tmp.begin());
  std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(),
            0.f);
  fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
  // Convolve in the frequency domain (scaled to undo the FFT normalization).
  constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize);
  std::fill(tmp.begin(), tmp.end(), 0.f);
  fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor);
  fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false);
  // Extract the auto-correlation coefficients.
  std::copy(tmp.begin() + kConvolutionLength - 1,
            tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1,
            auto_corr.begin());
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,49 @@
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
#include <memory>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/utility/pffft_wrapper.h"
namespace webrtc {
namespace rnn_vad {
// Class to compute the auto correlation on the pitch buffer for a target pitch
// interval, using FFT-based convolution.
class AutoCorrelationCalculator {
 public:
  AutoCorrelationCalculator();
  AutoCorrelationCalculator(const AutoCorrelationCalculator&) = delete;
  AutoCorrelationCalculator& operator=(const AutoCorrelationCalculator&) =
      delete;
  ~AutoCorrelationCalculator();
  // Computes the auto-correlation coefficients for a target pitch interval.
  // |auto_corr| indexes are inverted lags.
  void ComputeOnPitchBuffer(
      rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
      rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
 private:
  Pffft fft_;
  // Work buffers created by |fft_|.
  std::unique_ptr<Pffft::FloatBuffer> tmp_;
  std::unique_ptr<Pffft::FloatBuffer> X_;
  std::unique_ptr<Pffft::FloatBuffer> H_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_

View File

@ -0,0 +1,34 @@
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/system/arch.h"
#include "system_wrappers/include/cpu_features_wrapper.h"
namespace webrtc {
namespace rnn_vad {
// Returns the best available SIMD optimization for the running CPU: SSE2 on
// x86-family CPUs that report SSE2 support, NEON when compiled in, otherwise
// none.
Optimization DetectOptimization() {
#if defined(WEBRTC_ARCH_X86_FAMILY)
  if (GetCPUInfo(kSSE2) != 0) {
    return Optimization::kSse2;
  }
#endif
#if defined(WEBRTC_HAS_NEON)
  return Optimization::kNeon;
#endif
  return Optimization::kNone;
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,76 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
#include <stddef.h>
namespace webrtc {
namespace rnn_vad {
constexpr double kPi = 3.14159265358979323846;
constexpr size_t kSampleRate24kHz = 24000;
constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100;
constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
// Pitch buffer.
constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800;   // 0.00125 s.
constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5;  // 0.016 s.
constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
// 24 kHz analysis.
// Define a higher minimum pitch period for the initial search. This is used to
// avoid searching for very short periods, for which a refinement step is
// responsible.
constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
// Note: a second assert restating kMaxPitch24kHz > kInitialMinPitch24kHz was
// removed as redundant with the line below.
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// 12 kHz analysis (decimated by 2 from 24 kHz).
constexpr size_t kSampleRate12kHz = 12000;
constexpr size_t kFrameSize10ms12kHz = kSampleRate12kHz / 100;
constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
constexpr size_t kBufSize12kHz = kBufSize24kHz / 2;
constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2;
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
// 48 kHz constants.
constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
constexpr size_t kMaxPitch48kHz = kMaxPitch24kHz * 2;
// Spectral features.
constexpr size_t kNumBands = 22;
constexpr size_t kNumLowerBands = 6;
static_assert((0 < kNumLowerBands) && (kNumLowerBands < kNumBands), "");
constexpr size_t kCepstralCoeffsHistorySize = 8;
static_assert(kCepstralCoeffsHistorySize > 2,
              "The history size must at least be 3 to compute first and second "
              "derivatives.");
constexpr size_t kFeatureVectorSize = 42;
enum class Optimization { kNone, kSse2, kNeon };
// Detects what kind of optimizations to use for the code.
Optimization DetectOptimization();
}  // namespace rnn_vad
}  // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_

View File

@ -0,0 +1,90 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include <array>
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// High-pass filter coefficients: 2nd-order Butterworth, 30 Hz cutoff at a
// 24 kHz sample rate.
// Generated via "B, A = scipy.signal.butter(2, 30/12000, btype='highpass')"
const BiQuadFilter::BiQuadCoefficients kHpfConfig24k = {
    {0.99446179f, -1.98892358f, 0.99446179f},
    {-1.98889291f, 0.98895425f}};
}  // namespace
// Wires up the pitch buffer, the LP residual storage and their array views,
// initializes the optional high-pass filter and resets all state.
FeaturesExtractor::FeaturesExtractor()
    : use_high_pass_filter_(false),
      pitch_buf_24kHz_(),
      pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()),
      lp_residual_(kBufSize24kHz),
      lp_residual_view_(lp_residual_.data(), kBufSize24kHz),
      pitch_estimator_(),
      reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) {
  RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size());
  hpf_.Initialize(kHpfConfig24k);
  Reset();
}
FeaturesExtractor::~FeaturesExtractor() = default;
// Clears all buffered audio and spectral state so that no state leaks across
// streams; the HPF is only reset when it is actually in use.
void FeaturesExtractor::Reset() {
  spectral_features_extractor_.Reset();
  pitch_buf_24kHz_.Reset();
  if (use_high_pass_filter_) {
    hpf_.Reset();
  }
}
// Pushes |samples| into the pitch buffer, runs the feature pipeline
// (LPC -> LP residual -> pitch estimation -> spectral features) and writes
// the result into |feature_vector|. Returns true if silence was detected, in
// which case |feature_vector| is only partially written.
bool FeaturesExtractor::CheckSilenceComputeFeatures(
    rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
    rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
  // Pre-processing.
  if (use_high_pass_filter_) {
    std::array<float, kFrameSize10ms24kHz> samples_filtered;
    hpf_.Process(samples, samples_filtered);
    // Feed buffer with the pre-processed version of |samples|.
    pitch_buf_24kHz_.Push(samples_filtered);
  } else {
    // Feed buffer with |samples|.
    pitch_buf_24kHz_.Push(samples);
  }
  // Extract the LP residual.
  float lpc_coeffs[kNumLpcCoefficients];
  ComputeAndPostProcessLpcCoefficients(pitch_buf_24kHz_view_, lpc_coeffs);
  ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_);
  // Estimate pitch on the LP-residual and write the normalized pitch period
  // into the output vector (normalization based on training data stats).
  pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
  feature_vector[kFeatureVectorSize - 2] =
      0.01f * (static_cast<int>(pitch_info_48kHz_.period) - 300);
  // Extract lagged frames (according to the estimated pitch period). The
  // 48 kHz period is halved to index the 24 kHz buffer.
  RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz);
  auto lagged_frame = pitch_buf_24kHz_view_.subview(
      kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz);
  // Analyze reference and lagged frames checking if silence has been detected
  // and write the feature vector (the sub-views below partition
  // |feature_vector|; the last slot receives the silence/variability scalar).
  return spectral_features_extractor_.CheckSilenceComputeFeatures(
      reference_frame_view_, {lagged_frame.data(), kFrameSize20ms24kHz},
      {feature_vector.data() + kNumLowerBands, kNumBands - kNumLowerBands},
      {feature_vector.data(), kNumLowerBands},
      {feature_vector.data() + kNumBands, kNumLowerBands},
      {feature_vector.data() + kNumBands + kNumLowerBands, kNumLowerBands},
      {feature_vector.data() + kNumBands + 2 * kNumLowerBands, kNumLowerBands},
      &feature_vector[kFeatureVectorSize - 1]);
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/biquad_filter.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
namespace webrtc {
namespace rnn_vad {
// Feature extractor to feed the VAD RNN.
class FeaturesExtractor {
 public:
  FeaturesExtractor();
  FeaturesExtractor(const FeaturesExtractor&) = delete;
  FeaturesExtractor& operator=(const FeaturesExtractor&) = delete;
  ~FeaturesExtractor();
  // Clears all buffered state.
  void Reset();
  // Analyzes the samples, computes the feature vector and returns true if
  // silence is detected (false if not). When silence is detected,
  // |feature_vector| is partially written and therefore must not be used to
  // feed the VAD RNN.
  bool CheckSilenceComputeFeatures(
      rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
      rtc::ArrayView<float, kFeatureVectorSize> feature_vector);
 private:
  const bool use_high_pass_filter_;
  // TODO(bugs.webrtc.org/7494): Remove HPF depending on how AGC2 is used in APM
  // and on whether an HPF is already used as pre-processing step in APM.
  BiQuadFilter hpf_;
  // Ring buffer of the most recent 24 kHz samples.
  SequenceBuffer<float, kBufSize24kHz, kFrameSize10ms24kHz, kFrameSize20ms24kHz>
      pitch_buf_24kHz_;
  rtc::ArrayView<const float, kBufSize24kHz> pitch_buf_24kHz_view_;
  // LP residual storage and a fixed-size view over it.
  std::vector<float> lp_residual_;
  rtc::ArrayView<float, kBufSize24kHz> lp_residual_view_;
  PitchEstimator pitch_estimator_;
  rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
  SpectralFeaturesExtractor spectral_features_extractor_;
  // Most recent pitch estimate, expressed at 48 kHz.
  PitchInfo pitch_info_48kHz_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_

View File

@ -0,0 +1,138 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <numeric>
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// Writes in |x_corr| the cross-correlation coefficients between |x| and |y|
// for every lag in {0, ..., max_lag - 1}, where max_lag is the size of
// |x_corr|. For a lag l, the longest available sub-arrays are used - i.e.,
// both operands have size "size of |x| - l". |x| and |y| must have the same
// size.
void ComputeCrossCorrelation(
    rtc::ArrayView<const float> x,
    rtc::ArrayView<const float> y,
    rtc::ArrayView<float, kNumLpcCoefficients> x_corr) {
  RTC_DCHECK_EQ(x.size(), y.size());
  constexpr size_t kMaxLag = x_corr.size();
  RTC_DCHECK_LT(kMaxLag, x.size());
  for (size_t k = 0; k < kMaxLag; ++k) {
    // Dot product between |x| without its last k samples and |y| delayed by k
    // samples.
    x_corr[k] =
        std::inner_product(x.begin(), x.end() - k, y.begin() + k, 0.f);
  }
}
// Applies denoising to the auto-correlation coefficients.
void DenoiseAutoCorrelation(
    rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
  // Assume a -40 dB white noise floor: slightly boost the zero-lag term.
  auto_corr[0] *= 1.0001f;
  // Progressively damp the higher lag terms.
  for (size_t k = 1; k < kNumLpcCoefficients; ++k) {
    const float damping = 0.008f * k;
    auto_corr[k] -= auto_corr[k] * damping * damping;
  }
}
// Computes the initial inverse filter coefficients given the auto-correlation
// coefficients of an input frame.
void ComputeInitialInverseFilterCoefficients(
    rtc::ArrayView<const float, kNumLpcCoefficients> auto_corr,
    rtc::ArrayView<float, kNumLpcCoefficients - 1> lpc_coeffs) {
  // Levinson-Durbin recursion: iteratively solve for the inverse filter
  // coefficients, one order at a time, tracking the prediction error.
  float error = auto_corr[0];
  for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
    // Reflection coefficient for order i + 1.
    float reflection_coeff = 0.f;
    for (size_t j = 0; j < i; ++j) {
      reflection_coeff += lpc_coeffs[j] * auto_corr[i - j];
    }
    reflection_coeff += auto_corr[i + 1];
    // Avoid division by numbers close to zero.
    constexpr float kMinErrorMagnitude = 1e-6f;
    if (std::fabs(error) < kMinErrorMagnitude) {
      error = std::copysign(kMinErrorMagnitude, error);
    }
    reflection_coeff /= -error;
    // Update LPC coefficients and total error.
    lpc_coeffs[i] = reflection_coeff;
    // Note: despite the formatting, the condition parses as
    // j < ((i + 1) >> 1) - i.e., only the first half of the taps is visited;
    // each iteration updates the symmetric pair (j, i - 1 - j).
    for (size_t j = 0; j<(i + 1)>> 1; ++j) {
      const float tmp1 = lpc_coeffs[j];
      const float tmp2 = lpc_coeffs[i - 1 - j];
      lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2;
      lpc_coeffs[i - 1 - j] = tmp2 + reflection_coeff * tmp1;
    }
    error -= reflection_coeff * reflection_coeff * error;
    // Early exit when the residual error is already small.
    if (error < 0.001f * auto_corr[0]) {
      break;
    }
  }
}
} // namespace
// Computes the LPC coefficients of |x| and applies the post-processing used
// for pitch estimation (damping plus expansion to kNumLpcCoefficients taps).
void ComputeAndPostProcessLpcCoefficients(
    rtc::ArrayView<const float> x,
    rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs) {
  // Auto-correlation of |x| (cross-correlation of |x| with itself).
  std::array<float, kNumLpcCoefficients> corr;
  ComputeCrossCorrelation(x, x, {corr.data(), corr.size()});
  if (corr[0] == 0.f) {  // Empty frame.
    std::fill(lpc_coeffs.begin(), lpc_coeffs.end(), 0);
    return;
  }
  DenoiseAutoCorrelation({corr.data(), corr.size()});
  std::array<float, kNumLpcCoefficients - 1> initial_coeffs{};
  ComputeInitialInverseFilterCoefficients(corr, initial_coeffs);
  // LPC coefficients post-processing.
  // TODO(bugs.webrtc.org/9076): Consider removing these steps.
  // Progressively damp the initial coefficients.
  float damping = 1.f;
  for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
    damping *= 0.9f;
    initial_coeffs[i] *= damping;
  }
  // Expand the kNumLpcCoefficients - 1 damped taps into kNumLpcCoefficients
  // output taps.
  const float c = 0.8f;
  lpc_coeffs[0] = initial_coeffs[0] + c;
  lpc_coeffs[1] = initial_coeffs[1] + c * initial_coeffs[0];
  lpc_coeffs[2] = initial_coeffs[2] + c * initial_coeffs[1];
  lpc_coeffs[3] = initial_coeffs[3] + c * initial_coeffs[2];
  lpc_coeffs[4] = c * initial_coeffs[3];
}
void ComputeLpResidual(
    rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
    rtc::ArrayView<const float> x,
    rtc::ArrayView<float> y) {
  RTC_DCHECK_LT(kNumLpcCoefficients, x.size());
  RTC_DCHECK_EQ(x.size(), y.size());
  // Shift register holding the most recent input samples (zero-initialized,
  // i.e., the filter starts from a silent history).
  std::array<float, kNumLpcCoefficients> input_chunk;
  input_chunk.fill(0.f);
  for (size_t i = 0; i < y.size(); ++i) {
    // FIR filtering: dot product of the history with the LPC taps, with x[i]
    // as the initial accumulator value. x[i] is read before y[i] is written,
    // which makes in-place computation (|x| == |y|) safe.
    const float sum = std::inner_product(input_chunk.begin(), input_chunk.end(),
                                         lpc_coeffs.begin(), x[i]);
    // Circular shift and add a new sample.
    for (size_t j = kNumLpcCoefficients - 1; j > 0; --j)
      input_chunk[j] = input_chunk[j - 1];
    input_chunk[0] = x[i];
    // Copy result.
    y[i] = sum;
  }
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,41 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
#include <stddef.h>
#include "api/array_view.h"
namespace webrtc {
namespace rnn_vad {
// LPC inverse filter length.
constexpr size_t kNumLpcCoefficients = 5;
// Given a frame |x|, computes a post-processed version of LPC coefficients
// tailored for pitch estimation.
void ComputeAndPostProcessLpcCoefficients(
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs);
// Computes the LP residual for the input frame |x| and the LPC coefficients
// |lpc_coeffs|. |y| and |x| can point to the same array for in-place
// computation.
void ComputeLpResidual(
rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
rtc::ArrayView<const float> x,
rtc::ArrayView<float> y);
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
namespace webrtc {
namespace rnn_vad {
// Stores pitch period and gain information. The pitch gain measures the
// strength of the pitch (the higher, the stronger).
struct PitchInfo {
  PitchInfo() = default;
  PitchInfo(int p, float g) : period(p), gain(g) {}
  int period = 0;
  float gain = 0.f;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_

View File

@ -0,0 +1,56 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
#include <array>
#include <cstddef>
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
// Allocates the 12 kHz decimated pitch buffer and the auto-correlation
// storage, and binds the fixed-size views used by Estimate().
PitchEstimator::PitchEstimator()
    : pitch_buf_decimated_(kBufSize12kHz),
      pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
      auto_corr_(kNumInvertedLags12kHz),
      auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
  // Sanity-check that the views cover the whole backing vectors.
  RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
  RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size());
}
PitchEstimator::~PitchEstimator() = default;
// Estimates the pitch on a 24 kHz pitch buffer: coarse search at 12 kHz,
// refinement at 24 kHz, then harmonic check. Updates and returns the cached
// 48 kHz pitch estimation, which is also fed back into the next call for
// pitch tracking.
PitchInfo PitchEstimator::Estimate(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
  // Perform the initial pitch search at 12 kHz.
  Decimate2x(pitch_buf, pitch_buf_decimated_view_);
  auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
                                             auto_corr_view_);
  std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
      auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
  // Refine the pitch period estimation.
  // The refinement is done using the pitch buffer that contains 24 kHz samples.
  // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
  // to 24 kHz.
  pitch_candidates_inv_lags[0] *= 2;
  pitch_candidates_inv_lags[1] *= 2;
  size_t pitch_inv_lag_48kHz =
      RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags);
  // Look for stronger harmonics to find the final pitch period and its gain.
  RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
  // Convert the inverted lag to a period before the harmonic analysis.
  last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
      pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
  return last_pitch_48kHz_;
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,49 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
namespace webrtc {
namespace rnn_vad {
// Pitch estimator.
class PitchEstimator {
public:
  PitchEstimator();
  PitchEstimator(const PitchEstimator&) = delete;
  PitchEstimator& operator=(const PitchEstimator&) = delete;
  ~PitchEstimator();
  // Estimates the pitch period and gain. Returns the pitch estimation data for
  // 48 kHz.
  PitchInfo Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf);
private:
  // Last estimation, fed back into the next one for pitch tracking.
  PitchInfo last_pitch_48kHz_;
  AutoCorrelationCalculator auto_corr_calculator_;
  // 12 kHz decimated copy of the pitch buffer and a fixed-size view on it.
  std::vector<float> pitch_buf_decimated_;
  rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
  // Auto-correlation coefficients (indexed by inverted lag) and a view.
  std::vector<float> auto_corr_;
  rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr_view_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_

View File

@ -0,0 +1,403 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// Converts a lag into an inverted lag, i.e., its offset from the maximum
// pitch period (only for 24 kHz).
size_t GetInvertedLag(size_t lag) {
  RTC_DCHECK_LE(lag, kMaxPitch24kHz);
  const size_t inverted_lag = kMaxPitch24kHz - lag;
  return inverted_lag;
}
// Computes one auto-correlation coefficient: the dot product between the most
// recent frame of |pitch_buf| (the samples after |max_pitch_period|) and the
// same-length window starting at |inv_lag|.
float ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf,
                                  size_t inv_lag,
                                  size_t max_pitch_period) {
  RTC_DCHECK_LT(inv_lag, pitch_buf.size());
  RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
  RTC_DCHECK_LE(inv_lag, max_pitch_period);
  // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
  return std::inner_product(pitch_buf.begin() + max_pitch_period,
                            pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f);
}
// Computes a pseudo-interpolation offset for an estimated pitch period |lag| by
// looking at the auto-correlation coefficients in the neighborhood of |lag|.
// (namely, |prev_auto_corr|, |lag_auto_corr| and |next_auto_corr|). The output
// is a lag in {-1, 0, +1}.
// TODO(bugs.webrtc.org/9076): Consider removing pseudo-i since it
// is relevant only if the spectral analysis works at a sample rate that is
// twice as that of the pitch buffer (not so important instead for the estimated
// pitch period feature fed into the RNN).
// Returns the pseudo-interpolation offset in {-1, 0, +1} for an estimated
// pitch period given the auto-correlation coefficients at the previous lag
// (|prev_auto_corr|), at the lag itself (|lag_auto_corr|) and at the next lag
// (|next_auto_corr|).
// Fix: the first parameter (the lag itself) was never read by the body; it is
// now anonymous so the signature documents that and no unused-parameter
// warning is emitted. Call sites are unaffected.
int GetPitchPseudoInterpolationOffset(size_t /*lag*/,
                                      float prev_auto_corr,
                                      float lag_auto_corr,
                                      float next_auto_corr) {
  const float& a = prev_auto_corr;
  const float& b = lag_auto_corr;
  const float& c = next_auto_corr;
  int offset = 0;
  if ((c - a) > 0.7f * (b - a)) {
    offset = 1;  // |c| is the largest auto-correlation coefficient.
  } else if ((a - c) > 0.7f * (b - c)) {
    offset = -1;  // |a| is the largest auto-correlation coefficient.
  }
  return offset;
}
// Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The
// output sample rate is twice as that of |lag|.
size_t PitchPseudoInterpolationLagPitchBuf(
    size_t lag,
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
  int offset = 0;
  // Cannot apply pseudo-interpolation at the boundaries.
  if (lag > 0 && lag < kMaxPitch24kHz) {
    // Look at the auto-correlation around |lag| to decide the offset.
    offset = GetPitchPseudoInterpolationOffset(
        lag,
        ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1),
                                    kMaxPitch24kHz),
        ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag),
                                    kMaxPitch24kHz),
        ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1),
                                    kMaxPitch24kHz));
  }
  // Lag doubled (output sample rate is twice the input one) plus the
  // interpolation offset.
  return 2 * lag + offset;
}
// Refines a pitch period |inv_lag| encoded as inverted lag with
// pseudo-interpolation. The output sample rate is twice as that of
// |inv_lag|.
size_t PitchPseudoInterpolationInvLagAutoCorr(
    size_t inv_lag,
    rtc::ArrayView<const float> auto_corr) {
  int offset = 0;
  // Cannot apply pseudo-interpolation at the boundaries.
  if (inv_lag > 0 && inv_lag < auto_corr.size() - 1) {
    // |auto_corr| is indexed by inverted lag, hence the neighbors are passed
    // in reversed order (inv_lag + 1 is the previous lag).
    offset = GetPitchPseudoInterpolationOffset(inv_lag, auto_corr[inv_lag + 1],
                                               auto_corr[inv_lag],
                                               auto_corr[inv_lag - 1]);
  }
  // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
  // be subtracted since |inv_lag| is an inverted lag but offset is a lag.
  return 2 * inv_lag + offset;
}
// Integer multipliers used in CheckLowerPitchPeriodsAndComputePitchGain() when
// looking for sub-harmonics.
// The values have been chosen to serve the following algorithm. Given the
// initial pitch period T, we examine whether one of its harmonics is the true
// fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of
// these harmonics, in addition to the pitch gain of itself, we choose one
// multiple of its pitch period, n*T/k, to validate it (by averaging their pitch
// gains). The multiplier n is chosen so that n*T/k is used only one time over
// all k. When for example k = 4, we should also expect a peak at 3*T/4. When
// k = 8 instead we don't want to look at 2*T/8, since we have already checked
// T/4 before. Instead, we look at T*3/8.
// The array can be generate in Python as follows:
// from fractions import Fraction
// # Smallest positive integer not in X.
// def mex(X):
// for i in range(1, int(max(X)+2)):
// if i not in X:
// return i
// # Visited multiples of the period.
// S = {1}
// for n in range(2, 16):
// sn = mex({n * i for i in S} | {1})
// S = S | {Fraction(1, n), Fraction(sn, n)}
// print(sn, end=', ')
constexpr std::array<int, 14> kSubHarmonicMultipliers = {
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
// a sample rate of 24 kHz. Computed as [5*k*k for k in range(2, 16)].
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
    {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
} // namespace
// Decimates |src| by 2x into |dst| by keeping every other sample.
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
                rtc::ArrayView<float, kBufSize12kHz> dst) {
  // TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter.
  static_assert(src.size() == 2 * dst.size(), "");
  size_t dst_index = 0;
  for (size_t i = 0; i < src.size(); i += 2) {
    dst[dst_index] = src[i];
    ++dst_index;
  }
}
// Returns the gain threshold that a candidate pitch period |t1| (derived from
// the initial estimate |t0| via the integer ratio |k|) must exceed to be
// accepted. The threshold is lowered when |t1| is close to the previously
// estimated period (pitch tracking) and raised for short periods (high
// frequencies) to reduce false positives from short-term correlations.
float ComputePitchGainThreshold(int candidate_pitch_period,
                                int pitch_period_ratio,
                                int initial_pitch_period,
                                float initial_pitch_gain,
                                int prev_pitch_period,
                                float prev_pitch_gain) {
  // Map arguments to more compact aliases.
  const int& t1 = candidate_pitch_period;
  const int& k = pitch_period_ratio;
  const int& t0 = initial_pitch_period;
  const float& g0 = initial_pitch_gain;
  const int& t_prev = prev_pitch_period;
  const float& g_prev = prev_pitch_gain;
  // Validate input.
  RTC_DCHECK_GE(t1, 0);
  RTC_DCHECK_GE(k, 2);
  RTC_DCHECK_GE(t0, 0);
  RTC_DCHECK_GE(t_prev, 0);
  // Compute a term that lowers the threshold when |t1| is close to the last
  // estimated period |t_prev| - i.e., pitch tracking.
  float lower_threshold_term = 0;
  if (abs(t1 - t_prev) <= 1) {
    // The candidate pitch period is within 1 sample from the previous one.
    // Make the candidate at |t1| very easy to be accepted.
    lower_threshold_term = g_prev;
  } else if (abs(t1 - t_prev) == 2 &&
             t0 > kInitialPitchPeriodThresholds[k - 2]) {
    // The candidate pitch period is 2 samples far from the previous one and the
    // period |t0| (from which |t1| has been derived) is greater than a
    // threshold. Make |t1| easy to be accepted.
    lower_threshold_term = 0.5f * g_prev;
  }
  // Set the threshold based on the gain of the initial estimate |t0|. Also
  // reduce the chance of false positives caused by a bias towards high
  // frequencies (originating from short-term correlations).
  float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
  // Fix: test the narrower range first. The original code checked
  // |t1| < 3 * kMinPitch24kHz before |t1| < 2 * kMinPitch24kHz, which made the
  // second branch unreachable (any period below 2 * kMinPitch24kHz is also
  // below 3 * kMinPitch24kHz), so the stricter "even higher frequency"
  // threshold was never applied.
  if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
    // Even higher frequency.
    threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
  } else if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
    // High frequency.
    threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
  }
  return threshold;
}
void ComputeSlidingFrameSquareEnergies(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
    rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values) {
  // Energy of the most recent frame (lag 0), computed as the auto-correlation
  // of that frame with itself.
  float yy =
      ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz);
  yy_values[0] = yy;
  // Slide the frame one sample at a time: drop the newest sample's square,
  // add the square of the sample entering on the older side.
  for (size_t i = 1; i < yy_values.size(); ++i) {
    // NOTE(review): this first check is subsumed by the next one and is
    // therefore redundant; kept as-is.
    RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz);
    RTC_DCHECK_LE(i, kMaxPitch24kHz);
    const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i];
    const float new_coeff = pitch_buf[kMaxPitch24kHz - i];
    yy -= old_coeff * old_coeff;
    yy += new_coeff * new_coeff;
    // Guard against negative energies caused by floating point error.
    yy = std::max(0.f, yy);
    yy_values[i] = yy;
  }
}
std::array<size_t, 2> FindBestPitchPeriods(
    rtc::ArrayView<const float> auto_corr,
    rtc::ArrayView<const float> pitch_buf,
    size_t max_pitch_period) {
  // Stores a pitch candidate period and strength information.
  struct PitchCandidate {
    // Pitch period encoded as inverted lag.
    size_t period_inverted_lag = 0;
    // Pitch strength encoded as a ratio.
    float strength_numerator = -1.f;
    float strength_denominator = 0.f;
    // Compare the strength of two pitch candidates.
    bool HasStrongerPitchThan(const PitchCandidate& b) const {
      // Comparing the numerator/denominator ratios without using divisions.
      return strength_numerator * b.strength_denominator >
             b.strength_numerator * strength_denominator;
    }
  };
  RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
  RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
  const size_t frame_size = pitch_buf.size() - max_pitch_period;
  // Initial energy of the lagged frame at inverted lag 0, updated
  // incrementally inside the loop below.
  // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
  float yy =
      std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1,
                         pitch_buf.begin(), 1.f);
  // Search best and second best pitches by looking at the scaled
  // auto-correlation.
  PitchCandidate candidate;
  PitchCandidate best;
  PitchCandidate second_best;
  second_best.period_inverted_lag = 1;
  for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
    // A pitch candidate must have positive correlation.
    if (auto_corr[inv_lag] > 0) {
      candidate.period_inverted_lag = inv_lag;
      candidate.strength_numerator = auto_corr[inv_lag] * auto_corr[inv_lag];
      candidate.strength_denominator = yy;
      if (candidate.HasStrongerPitchThan(second_best)) {
        if (candidate.HasStrongerPitchThan(best)) {
          second_best = best;
          best = candidate;
        } else {
          second_best = candidate;
        }
      }
    }
    // Update |yy| (the lagged frame energy) for the next inverted lag.
    const float old_coeff = pitch_buf[inv_lag];
    const float new_coeff = pitch_buf[inv_lag + frame_size];
    yy -= old_coeff * old_coeff;
    yy += new_coeff * new_coeff;
    // Guard against negative energies caused by floating point error.
    yy = std::max(0.f, yy);
  }
  return {{best.period_inverted_lag, second_best.period_inverted_lag}};
}
size_t RefinePitchPeriod48kHz(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
    rtc::ArrayView<const size_t, 2> inv_lags) {
  // Compute the auto-correlation terms only for neighbors of the given pitch
  // candidates (similar to what AutoCorrelationCalculator::ComputeOnPitchBuffer()
  // produces, but only for a few lag values).
  std::array<float, kNumInvertedLags24kHz> auto_corr;
  auto_corr.fill(0.f);  // Zeros become ignored lags in FindBestPitchPeriods().
  // True when the two inverted lags are at most 2 samples apart.
  auto is_neighbor = [](size_t i, size_t j) {
    return ((i > j) ? (i - j) : (j - i)) <= 2;
  };
  for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
    if (is_neighbor(inv_lag, inv_lags[0]) || is_neighbor(inv_lag, inv_lags[1]))
      auto_corr[inv_lag] =
          ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, kMaxPitch24kHz);
  }
  // Find best pitch at 24 kHz.
  const auto pitch_candidates_inv_lags = FindBestPitchPeriods(
      {auto_corr.data(), auto_corr.size()},
      {pitch_buf.data(), pitch_buf.size()}, kMaxPitch24kHz);
  const auto inv_lag = pitch_candidates_inv_lags[0];  // Refine the best.
  // Pseudo-interpolation (doubles the resolution, hence the 48 kHz result).
  return PitchPseudoInterpolationInvLagAutoCorr(inv_lag, auto_corr);
}
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
    int initial_pitch_period_48kHz,
    PitchInfo prev_pitch_48kHz) {
  RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
  RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
  // Stores information for a refined pitch candidate.
  struct RefinedPitchCandidate {
    RefinedPitchCandidate() {}
    RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy)
        : period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
    int period_24kHz;
    // Pitch strength information.
    float gain;
    // Additional pitch strength information used for the final estimation of
    // pitch gain.
    float xy;  // Cross-correlation.
    float yy;  // Auto-correlation.
  };
  // Initialize.
  std::array<float, kMaxPitch24kHz + 1> yy_values;
  ComputeSlidingFrameSquareEnergies(pitch_buf,
                                    {yy_values.data(), yy_values.size()});
  // Energy of the most recent frame (lag 0).
  const float xx = yy_values[0];
  // Helper lambdas.
  const auto pitch_gain = [](float xy, float yy, float xx) {
    RTC_DCHECK_LE(0.f, xx * yy);
    return xy / std::sqrt(1.f + xx * yy);
  };
  // Initial pitch candidate gain. The 48 kHz period is halved to index the
  // 24 kHz buffer and clamped to the maximum period.
  RefinedPitchCandidate best_pitch;
  best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2,
                                     static_cast<int>(kMaxPitch24kHz - 1));
  best_pitch.xy = ComputeAutoCorrelationCoeff(
      pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
  best_pitch.yy = yy_values[best_pitch.period_24kHz];
  best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx);
  // Store the initial pitch period information.
  const size_t initial_pitch_period = best_pitch.period_24kHz;
  const float initial_pitch_gain = best_pitch.gain;
  // Given the initial pitch estimation, check lower periods (i.e., harmonics).
  const auto alternative_period = [](int period, int k, int n) -> int {
    RTC_DCHECK_GT(k, 0);
    return (2 * n * period + k) / (2 * k);  // Same as round(n*period/k).
  };
  for (int k = 2; k < static_cast<int>(kSubHarmonicMultipliers.size() + 2);
       ++k) {
    // Candidate period T/k; stop once it falls below the minimum period.
    int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
    if (static_cast<size_t>(candidate_pitch_period) < kMinPitch24kHz) {
      break;
    }
    // When looking at |candidate_pitch_period|, we also look at one of its
    // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
    // |k| == 2 is a special case since |candidate_pitch_secondary_period| might
    // be greater than the maximum pitch period.
    int candidate_pitch_secondary_period = alternative_period(
        initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
    RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
    if (k == 2 &&
        candidate_pitch_secondary_period > static_cast<int>(kMaxPitch24kHz)) {
      candidate_pitch_secondary_period = initial_pitch_period;
    }
    RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
        << "The lower pitch period and the additional sub-harmonic must not "
           "coincide.";
    // Compute an auto-correlation score for the primary pitch candidate
    // |candidate_pitch_period| by also looking at its possible sub-harmonic
    // |candidate_pitch_secondary_period|.
    float xy_primary_period = ComputeAutoCorrelationCoeff(
        pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz);
    float xy_secondary_period = ComputeAutoCorrelationCoeff(
        pitch_buf, GetInvertedLag(candidate_pitch_secondary_period),
        kMaxPitch24kHz);
    // Average the cross- and auto-correlations of the two periods.
    float xy = 0.5f * (xy_primary_period + xy_secondary_period);
    float yy = 0.5f * (yy_values[candidate_pitch_period] +
                       yy_values[candidate_pitch_secondary_period]);
    float candidate_pitch_gain = pitch_gain(xy, yy, xx);
    // Maybe update best period.
    float threshold = ComputePitchGainThreshold(
        candidate_pitch_period, k, initial_pitch_period, initial_pitch_gain,
        prev_pitch_48kHz.period / 2, prev_pitch_48kHz.gain);
    if (candidate_pitch_gain > threshold) {
      best_pitch = {candidate_pitch_period, candidate_pitch_gain, xy, yy};
    }
  }
  // Final pitch gain and period.
  best_pitch.xy = std::max(0.f, best_pitch.xy);
  RTC_DCHECK_LE(0.f, best_pitch.yy);
  // Alternative gain expression, capped by the gain computed above.
  float final_pitch_gain = (best_pitch.yy <= best_pitch.xy)
                               ? 1.f
                               : best_pitch.xy / (best_pitch.yy + 1.f);
  final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
  // Convert the period back to 48 kHz via pseudo-interpolation and clamp it to
  // the minimum period.
  int final_pitch_period_48kHz = std::max(
      kMinPitch48kHz,
      PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
  return {final_pitch_period_48kHz, final_pitch_gain};
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,77 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_
#include <stddef.h>
#include <array>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
namespace webrtc {
namespace rnn_vad {
// Performs 2x decimation without any anti-aliasing filter.
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
rtc::ArrayView<float, kBufSize12kHz> dst);
// Computes a gain threshold for a candidate pitch period given the initial and
// the previous pitch period and gain estimates and the pitch period ratio used
// to derive the candidate pitch period from the initial period.
float ComputePitchGainThreshold(int candidate_pitch_period,
int pitch_period_ratio,
int initial_pitch_period,
float initial_pitch_gain,
int prev_pitch_period,
float prev_pitch_gain);
// Computes the sum of squared samples for every sliding frame in the pitch
// buffer. |yy_values| indexes are lags.
//
// The pitch buffer is structured as depicted below:
// |.........|...........|
// a b
// The part on the left, named "a" contains the oldest samples, whereas "b" the
// most recent ones. The size of "a" corresponds to the maximum pitch period,
// that of "b" to the frame size (e.g., 16 ms and 20 ms respectively).
void ComputeSlidingFrameSquareEnergies(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values);
// Given the auto-correlation coefficients stored using inverted lags (as
// produced by AutoCorrelationCalculator::ComputeOnPitchBuffer()), returns the
// best and the second best pitch periods.
std::array<size_t, 2> FindBestPitchPeriods(
rtc::ArrayView<const float> auto_corr,
rtc::ArrayView<const float> pitch_buf,
size_t max_pitch_period);
// Refines the pitch period estimation given the pitch buffer |pitch_buf| and
// the initial pitch period estimation |inv_lags|. Returns an inverted lag at
// 48 kHz.
size_t RefinePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
rtc::ArrayView<const size_t, 2> inv_lags);
// Refines the pitch period estimation and compute the pitch gain. Returns the
// refined pitch estimation data at 48 kHz.
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
int initial_pitch_period_48kHz,
PitchInfo prev_pitch_48kHz);
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
#include <array>
#include <cstring>
#include <type_traits>

#include "api/array_view.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
// Ring buffer for N arrays of type T each one with size S.
template <typename T, size_t S, size_t N>
class RingBuffer {
  static_assert(S > 0, "");
  static_assert(N > 0, "");
  // Arithmetic T also guarantees trivial copyability, which makes the
  // memcpy() in Push() below well-defined.
  static_assert(std::is_arithmetic<T>::value,
                "Integral or floating point required.");
public:
  RingBuffer() : tail_(0) {}
  RingBuffer(const RingBuffer&) = delete;
  RingBuffer& operator=(const RingBuffer&) = delete;
  ~RingBuffer() = default;
  // Set the ring buffer values to zero.
  void Reset() { buffer_.fill(0); }
  // Replace the least recently pushed array in the buffer with |new_values|.
  void Push(rtc::ArrayView<const T, S> new_values) {
    std::memcpy(buffer_.data() + S * tail_, new_values.data(), S * sizeof(T));
    tail_ += 1;
    if (tail_ == N)
      tail_ = 0;
  }
  // Return an array view onto the array with a given delay. A view on the last
  // and least recently pushed array is returned when |delay| is 0 and N - 1
  // respectively.
  rtc::ArrayView<const T, S> GetArrayView(size_t delay) const {
    const int delay_int = static_cast<int>(delay);
    RTC_DCHECK_LE(0, delay_int);
    RTC_DCHECK_LT(delay_int, N);
    // Walk backwards from the most recently written slot, wrapping around.
    int offset = tail_ - 1 - delay_int;
    if (offset < 0)
      offset += N;
    return {buffer_.data() + S * offset, S};
  }
private:
  int tail_;  // Index of the least recently pushed sub-array.
  std::array<T, S * N> buffer_{};
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_

View File

@ -0,0 +1,425 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
// Defines WEBRTC_ARCH_X86_FAMILY, used below.
#include "rtc_base/system/arch.h"
#if defined(WEBRTC_HAS_NEON)
#include <arm_neon.h>
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
#include <algorithm>
#include <array>
#include <cmath>
#include <numeric>
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
namespace webrtc {
namespace rnn_vad {
namespace {
using rnnoise::kWeightsScale;
using rnnoise::kInputLayerInputSize;
static_assert(kFeatureVectorSize == kInputLayerInputSize, "");
using rnnoise::kInputDenseBias;
using rnnoise::kInputDenseWeights;
using rnnoise::kInputLayerOutputSize;
static_assert(kInputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
"Increase kFullyConnectedLayersMaxUnits.");
using rnnoise::kHiddenGruBias;
using rnnoise::kHiddenGruRecurrentWeights;
using rnnoise::kHiddenGruWeights;
using rnnoise::kHiddenLayerOutputSize;
static_assert(kHiddenLayerOutputSize <= kRecurrentLayersMaxUnits,
"Increase kRecurrentLayersMaxUnits.");
using rnnoise::kOutputDenseBias;
using rnnoise::kOutputDenseWeights;
using rnnoise::kOutputLayerOutputSize;
static_assert(kOutputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
"Increase kFullyConnectedLayersMaxUnits.");
using rnnoise::SigmoidApproximated;
using rnnoise::TansigApproximated;
// Rectified linear unit (ReLU) activation: clamps negative inputs to zero.
inline float RectifiedLinearUnit(float x) {
  return std::max(0.f, x);
}
// Converts quantized int8 parameters to floats scaled by |kWeightsScale|.
std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
  std::vector<float> scaled_params;
  scaled_params.reserve(params.size());
  for (int8_t x : params) {
    scaled_params.push_back(rnnoise::kWeightsScale * static_cast<float>(x));
  }
  return scaled_params;
}
// TODO(bugs.chromium.org/10480): Hard-code optimized layout and remove this
// function to improve setup time.
// Casts and scales |weights| and re-arranges the layout so that, for each
// output unit |o|, its |input_size| coefficients are contiguous.
std::vector<float> GetPreprocessedFcWeights(
    rtc::ArrayView<const int8_t> weights,
    size_t output_size) {
  // With a single output unit the source layout is already contiguous.
  if (output_size == 1) {
    return GetScaledParams(weights);
  }
  // Transpose, scale and cast.
  const size_t input_size = rtc::CheckedDivExact(weights.size(), output_size);
  std::vector<float> w(weights.size());
  for (size_t i = 0; i < input_size; ++i) {
    for (size_t o = 0; o < output_size; ++o) {
      w[o * input_size + i] = rnnoise::kWeightsScale *
                              static_cast<float>(weights[i * output_size + o]);
    }
  }
  return w;
}
constexpr size_t kNumGruGates = 3;  // Update, reset, output.
// TODO(bugs.chromium.org/10480): Hard-coded optimized layout and remove this
// function to improve setup time.
// Casts and scales |tensor_src| for a GRU layer and re-arranges the layout.
// It works both for weights, recurrent weights and bias.
// Source layout: innermost index interleaves gates and output units
// (|tensor_src[i * stride_src + g * output_size + o]|); destination layout
// groups all coefficients of one gate together, and within a gate all
// coefficients of one output unit (|tensor_dst[g * stride_dst + o * n + i]|).
std::vector<float> GetPreprocessedGruTensor(
    rtc::ArrayView<const int8_t> tensor_src,
    size_t output_size) {
  // Transpose, cast and scale.
  // |n| is the size of the first dimension of the 3-dim tensor |weights|.
  // For weights n == input size, for the bias n == 1.
  const size_t n =
      rtc::CheckedDivExact(tensor_src.size(), output_size * kNumGruGates);
  const size_t stride_src = kNumGruGates * output_size;
  const size_t stride_dst = n * output_size;
  std::vector<float> tensor_dst(tensor_src.size());
  for (size_t g = 0; g < kNumGruGates; ++g) {
    for (size_t o = 0; o < output_size; ++o) {
      for (size_t i = 0; i < n; ++i) {
        tensor_dst[g * stride_dst + o * n + i] =
            rnnoise::kWeightsScale *
            static_cast<float>(
                tensor_src[i * stride_src + g * output_size + o]);
      }
    }
  }
  return tensor_dst;
}
// Computes a GRU update or reset gate: sigmoid(W x + R s + b), one value per
// output unit. |weights| and |recurrent_weights| are stored row-major per
// output unit.
void ComputeGruUpdateResetGates(size_t input_size,
                                size_t output_size,
                                rtc::ArrayView<const float> weights,
                                rtc::ArrayView<const float> recurrent_weights,
                                rtc::ArrayView<const float> bias,
                                rtc::ArrayView<const float> input,
                                rtc::ArrayView<const float> state,
                                rtc::ArrayView<float> gate) {
  for (size_t o = 0; o < output_size; ++o) {
    // Accumulate bias, input and recurrent contributions, then squash.
    float sum = bias[o];
    const float* w_row = weights.data() + o * input_size;
    for (size_t i = 0; i < input_size; ++i) {
      sum += input[i] * w_row[i];
    }
    const float* r_row = recurrent_weights.data() + o * output_size;
    for (size_t s = 0; s < output_size; ++s) {
      sum += state[s] * r_row[s];
    }
    gate[o] = SigmoidApproximated(sum);
  }
}
// Computes the GRU output gate: ReLU(W x + R (state element-wise-scaled by
// reset) + b), one value per output unit.
void ComputeGruOutputGate(size_t input_size,
                          size_t output_size,
                          rtc::ArrayView<const float> weights,
                          rtc::ArrayView<const float> recurrent_weights,
                          rtc::ArrayView<const float> bias,
                          rtc::ArrayView<const float> input,
                          rtc::ArrayView<const float> state,
                          rtc::ArrayView<const float> reset,
                          rtc::ArrayView<float> gate) {
  for (size_t o = 0; o < output_size; ++o) {
    // Accumulate bias, input and reset-gated recurrent contributions.
    float sum = bias[o];
    const float* w_row = weights.data() + o * input_size;
    for (size_t i = 0; i < input_size; ++i) {
      sum += input[i] * w_row[i];
    }
    const float* r_row = recurrent_weights.data() + o * output_size;
    for (size_t s = 0; s < output_size; ++s) {
      sum += state[s] * r_row[s] * reset[s];
    }
    gate[o] = RectifiedLinearUnit(sum);
  }
}
// Gated recurrent unit (GRU) layer un-optimized implementation.
// |weights|, |recurrent_weights| and |bias| each concatenate the parameters of
// the three gates in the order update, reset, output; the subviews below slice
// out one gate at a time. |state| is read and updated in place.
void ComputeGruLayerOutput(size_t input_size,
                           size_t output_size,
                           rtc::ArrayView<const float> input,
                           rtc::ArrayView<const float> weights,
                           rtc::ArrayView<const float> recurrent_weights,
                           rtc::ArrayView<const float> bias,
                           rtc::ArrayView<float> state) {
  RTC_DCHECK_EQ(input_size, input.size())
  // Stride and offset used to read parameter arrays.
  const size_t stride_in = input_size * output_size;
  const size_t stride_out = output_size * output_size;
  // Update gate.
  std::array<float, kRecurrentLayersMaxUnits> update;
  ComputeGruUpdateResetGates(
      input_size, output_size, weights.subview(0, stride_in),
      recurrent_weights.subview(0, stride_out), bias.subview(0, output_size),
      input, state, update);
  // Reset gate.
  std::array<float, kRecurrentLayersMaxUnits> reset;
  ComputeGruUpdateResetGates(
      input_size, output_size, weights.subview(stride_in, stride_in),
      recurrent_weights.subview(stride_out, stride_out),
      bias.subview(output_size, output_size), input, state, reset);
  // Output gate.
  std::array<float, kRecurrentLayersMaxUnits> output;
  ComputeGruOutputGate(
      input_size, output_size, weights.subview(2 * stride_in, stride_in),
      recurrent_weights.subview(2 * stride_out, stride_out),
      bias.subview(2 * output_size, output_size), input, state, reset, output);
  // Update output through the update gates and update the state.
  // Note: the output gate must be computed from the *old* state, hence the
  // state write happens only in this final loop.
  for (size_t o = 0; o < output_size; ++o) {
    output[o] = update[o] * state[o] + (1.f - update[o]) * output[o];
    state[o] = output[o];
  }
}
// Fully connected layer un-optimized implementation: for each output unit,
// dot product of |input| with the unit's weight row, seeded with the bias
// term, followed by the activation function.
void ComputeFullyConnectedLayerOutput(
    size_t input_size,
    size_t output_size,
    rtc::ArrayView<const float> input,
    rtc::ArrayView<const float> bias,
    rtc::ArrayView<const float> weights,
    rtc::FunctionView<float(float)> activation_function,
    rtc::ArrayView<float> output) {
  RTC_DCHECK_EQ(input.size(), input_size);
  RTC_DCHECK_EQ(bias.size(), output_size);
  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
  // TODO(bugs.chromium.org/9076): Benchmark how different layouts for
  // |weights_| change the performance across different platforms.
  for (size_t o = 0; o < output_size; ++o) {
    output[o] = activation_function(
        std::inner_product(input.begin(), input.end(),
                           weights.begin() + o * input_size, bias[o]));
  }
}
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Fully connected layer SSE2 implementation. Vectorizes the dot product four
// floats at a time and handles the tail with std::inner_product.
void ComputeFullyConnectedLayerOutputSse2(
    size_t input_size,
    size_t output_size,
    rtc::ArrayView<const float> input,
    rtc::ArrayView<const float> bias,
    rtc::ArrayView<const float> weights,
    rtc::FunctionView<float(float)> activation_function,
    rtc::ArrayView<float> output) {
  RTC_DCHECK_EQ(input.size(), input_size);
  RTC_DCHECK_EQ(bias.size(), output_size);
  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
  const size_t input_size_by_4 = input_size >> 2;  // Number of full 4-lanes.
  const size_t offset = input_size & ~3;  // First item of the scalar tail.
  for (size_t o = 0; o < output_size; ++o) {
    // Perform 128 bit vector operations.
    __m128 sum_wx_128 = _mm_set1_ps(0);
    const float* x_p = input.data();
    const float* w_p = weights.data() + o * input_size;
    for (size_t i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) {
      sum_wx_128 = _mm_add_ps(sum_wx_128,
                              _mm_mul_ps(_mm_loadu_ps(x_p), _mm_loadu_ps(w_p)));
    }
    // Extract the four partial sums. Storing with |_mm_storeu_ps| is
    // well-defined, whereas reading the |__m128| through a reinterpret_cast'ed
    // float pointer (as previously done) relied on implementation-defined
    // aliasing behavior.
    float v[4];
    _mm_storeu_ps(v, sum_wx_128);
    // Perform non-vector operations for any remaining items, sum up bias term
    // and results from the vectorized code, and apply the activation function.
    output[o] = activation_function(
        std::inner_product(input.begin() + offset, input.end(),
                           weights.begin() + o * input_size + offset,
                           bias[o] + v[0] + v[1] + v[2] + v[3]));
  }
}
#endif
} // namespace
// Ctor. Scales the int8 quantized |bias| and |weights| to float, re-arranging
// the weights layout (see GetPreprocessedFcWeights), and validates the sizes.
FullyConnectedLayer::FullyConnectedLayer(
    const size_t input_size,
    const size_t output_size,
    const rtc::ArrayView<const int8_t> bias,
    const rtc::ArrayView<const int8_t> weights,
    rtc::FunctionView<float(float)> activation_function,
    Optimization optimization)
    : input_size_(input_size),
      output_size_(output_size),
      bias_(GetScaledParams(bias)),
      weights_(GetPreprocessedFcWeights(weights, output_size)),
      activation_function_(activation_function),
      optimization_(optimization) {
  RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
      << "Static over-allocation of fully-connected layers output vectors is "
         "not sufficient.";
  RTC_DCHECK_EQ(output_size_, bias_.size())
      << "Mismatching output size and bias terms array size.";
  RTC_DCHECK_EQ(input_size_ * output_size_, weights_.size())
      << "Mismatching input-output size and weight coefficients array size.";
}
FullyConnectedLayer::~FullyConnectedLayer() = default;

// Returns a view on the valid portion of the over-allocated output buffer.
rtc::ArrayView<const float> FullyConnectedLayer::GetOutput() const {
  return {output_.data(), output_size_};
}
// Computes the layer output for |input| into the internal output buffer,
// dispatching to the SIMD implementation selected at construction time.
void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
  switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
    case Optimization::kSse2:
      ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input,
                                           bias_, weights_,
                                           activation_function_, output_);
      break;
#endif
#if defined(WEBRTC_HAS_NEON)
    case Optimization::kNeon:
      // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
      // Falls back on the un-optimized implementation for now.
      ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
                                       weights_, activation_function_, output_);
      break;
#endif
    default:
      // Portable un-optimized implementation.
      ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
                                       weights_, activation_function_, output_);
  }
}
// Ctor. Scales the int8 quantized GRU parameters to float, re-arranging their
// layout (see GetPreprocessedGruTensor), validates the sizes and zeroes the
// recurrent state.
GatedRecurrentLayer::GatedRecurrentLayer(
    const size_t input_size,
    const size_t output_size,
    const rtc::ArrayView<const int8_t> bias,
    const rtc::ArrayView<const int8_t> weights,
    const rtc::ArrayView<const int8_t> recurrent_weights,
    Optimization optimization)
    : input_size_(input_size),
      output_size_(output_size),
      bias_(GetPreprocessedGruTensor(bias, output_size)),
      weights_(GetPreprocessedGruTensor(weights, output_size)),
      recurrent_weights_(
          GetPreprocessedGruTensor(recurrent_weights, output_size)),
      optimization_(optimization) {
  RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
      << "Static over-allocation of recurrent layers state vectors is not "
         "sufficient.";
  RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
      << "Mismatching output size and bias terms array size.";
  RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
      << "Mismatching input-output size and weight coefficients array size.";
  RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
                recurrent_weights_.size())
      << "Mismatching input-output size and recurrent weight coefficients array"
         " size.";
  Reset();
}
GatedRecurrentLayer::~GatedRecurrentLayer() = default;

// Returns a view on the valid portion of the over-allocated state buffer.
rtc::ArrayView<const float> GatedRecurrentLayer::GetOutput() const {
  return {state_.data(), output_size_};
}

// Zeroes the recurrent state.
void GatedRecurrentLayer::Reset() {
  state_.fill(0.f);
}
// Computes the GRU layer output for |input|, updating the recurrent state in
// place. All branches currently fall back on the un-optimized implementation.
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
  switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
    case Optimization::kSse2:
      // TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
      ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
                            recurrent_weights_, bias_, state_);
      break;
#endif
#if defined(WEBRTC_HAS_NEON)
    case Optimization::kNeon:
      // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
      ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
                            recurrent_weights_, bias_, state_);
      break;
#endif
    default:
      ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
                            recurrent_weights_, bias_, state_);
  }
}
// Ctor. Builds the three-layer network (dense tansig input layer, GRU hidden
// layer, sigmoid dense output layer) from the rnnoise weights and checks that
// consecutive layer sizes chain correctly.
RnnBasedVad::RnnBasedVad()
    : input_layer_(kInputLayerInputSize,
                   kInputLayerOutputSize,
                   kInputDenseBias,
                   kInputDenseWeights,
                   TansigApproximated,
                   DetectOptimization()),
      hidden_layer_(kInputLayerOutputSize,
                    kHiddenLayerOutputSize,
                    kHiddenGruBias,
                    kHiddenGruWeights,
                    kHiddenGruRecurrentWeights,
                    DetectOptimization()),
      output_layer_(kHiddenLayerOutputSize,
                    kOutputLayerOutputSize,
                    kOutputDenseBias,
                    kOutputDenseWeights,
                    SigmoidApproximated,
                    DetectOptimization()) {
  // Input-output chaining size checks.
  RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size())
      << "The input and the hidden layers sizes do not match.";
  RTC_DCHECK_EQ(hidden_layer_.output_size(), output_layer_.input_size())
      << "The hidden and the output layers sizes do not match.";
}
RnnBasedVad::~RnnBasedVad() = default;
// Resets the recurrent state; the fully-connected layers are stateless.
void RnnBasedVad::Reset() {
  hidden_layer_.Reset();
}
// Returns the voice probability for |feature_vector|. When |is_silence| is
// true, the recurrent state is reset and 0 is returned without running the
// network.
float RnnBasedVad::ComputeVadProbability(
    rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
    bool is_silence) {
  if (is_silence) {
    Reset();
    return 0.f;
  }
  // Feed-forward pass: input layer -> GRU -> output layer.
  input_layer_.ComputeOutput(feature_vector);
  hidden_layer_.ComputeOutput(input_layer_.GetOutput());
  output_layer_.ComputeOutput(hidden_layer_.GetOutput());
  return output_layer_.GetOutput()[0];
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,126 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
#include <stddef.h>
#include <sys/types.h>
#include <array>
#include <vector>
#include "api/array_view.h"
#include "api/function_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/system/arch.h"
namespace webrtc {
namespace rnn_vad {
// Maximum number of units for a fully-connected layer. This value is used to
// over-allocate space for fully-connected layers output vectors (implemented as
// std::array). The value should equal the number of units of the largest
// fully-connected layer.
constexpr size_t kFullyConnectedLayersMaxUnits = 24;
// Maximum number of units for a recurrent layer. This value is used to
// over-allocate space for recurrent layers state vectors (implemented as
// std::array). The value should equal the number of units of the largest
// recurrent layer.
constexpr size_t kRecurrentLayersMaxUnits = 24;
// Fully-connected layer.
class FullyConnectedLayer {
 public:
  // Ctor. |bias| and |weights| are int8 quantized parameters converted to
  // scaled floats at construction time.
  FullyConnectedLayer(size_t input_size,
                      size_t output_size,
                      rtc::ArrayView<const int8_t> bias,
                      rtc::ArrayView<const int8_t> weights,
                      rtc::FunctionView<float(float)> activation_function,
                      Optimization optimization);
  FullyConnectedLayer(const FullyConnectedLayer&) = delete;
  FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
  ~FullyConnectedLayer();
  size_t input_size() const { return input_size_; }
  size_t output_size() const { return output_size_; }
  Optimization optimization() const { return optimization_; }
  // Returns a view on the output written by the last ComputeOutput() call.
  rtc::ArrayView<const float> GetOutput() const;
  // Computes the fully-connected layer output.
  void ComputeOutput(rtc::ArrayView<const float> input);
 private:
  const size_t input_size_;
  const size_t output_size_;
  const std::vector<float> bias_;
  const std::vector<float> weights_;
  rtc::FunctionView<float(float)> activation_function_;
  // The output vector of a recurrent layer has length equal to |output_size_|.
  // However, for efficiency, over-allocation is used.
  std::array<float, kFullyConnectedLayersMaxUnits> output_;
  const Optimization optimization_;
};
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
// activation functions for the update/reset and output gates respectively.
class GatedRecurrentLayer {
 public:
  // Ctor. |bias|, |weights| and |recurrent_weights| are int8 quantized
  // parameters converted to scaled floats at construction time.
  GatedRecurrentLayer(size_t input_size,
                      size_t output_size,
                      rtc::ArrayView<const int8_t> bias,
                      rtc::ArrayView<const int8_t> weights,
                      rtc::ArrayView<const int8_t> recurrent_weights,
                      Optimization optimization);
  GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
  GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
  ~GatedRecurrentLayer();
  size_t input_size() const { return input_size_; }
  size_t output_size() const { return output_size_; }
  Optimization optimization() const { return optimization_; }
  // Returns a view on the current recurrent state (the layer output).
  rtc::ArrayView<const float> GetOutput() const;
  // Zeroes the recurrent state.
  void Reset();
  // Computes the recurrent layer output and updates the status.
  void ComputeOutput(rtc::ArrayView<const float> input);
 private:
  const size_t input_size_;
  const size_t output_size_;
  const std::vector<float> bias_;
  const std::vector<float> weights_;
  const std::vector<float> recurrent_weights_;
  // The state vector of a recurrent layer has length equal to |output_size_|.
  // However, to avoid dynamic allocation, over-allocation is used.
  std::array<float, kRecurrentLayersMaxUnits> state_;
  const Optimization optimization_;
};
// Recurrent network based VAD: dense input layer, GRU hidden layer and dense
// output layer producing a single voice probability.
class RnnBasedVad {
 public:
  RnnBasedVad();
  RnnBasedVad(const RnnBasedVad&) = delete;
  RnnBasedVad& operator=(const RnnBasedVad&) = delete;
  ~RnnBasedVad();
  // Resets the internal (recurrent) state.
  void Reset();
  // Compute and returns the probability of voice (range: [0.0, 1.0]).
  // When |is_silence| is true, resets the state and returns 0.
  float ComputeVadProbability(
      rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
      bool is_silence);
 private:
  FullyConnectedLayer input_layer_;
  GatedRecurrentLayer hidden_layer_;
  FullyConnectedLayer output_layer_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_

View File

@ -0,0 +1,120 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include <array>
#include <string>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "common_audio/resampler/push_sinc_resampler.h"
#include "common_audio/wav_file.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/logging.h"
ABSL_FLAG(std::string, i, "", "Path to the input wav file");
ABSL_FLAG(std::string, f, "", "Path to the output features file");
ABSL_FLAG(std::string, o, "", "Path to the output VAD probabilities file");
namespace webrtc {
namespace rnn_vad {
namespace test {
// Reads a mono wav file, resamples it to 24 kHz, extracts RNN VAD features and
// writes the per-frame voice probabilities (and, optionally, the features) to
// the output files given via command line flags. Returns 0 on success.
int main(int argc, char* argv[]) {
  absl::ParseCommandLine(argc, argv);
  rtc::LogMessage::LogToDebug(rtc::LS_INFO);
  // Open wav input file and check properties.
  const std::string input_wav_file = absl::GetFlag(FLAGS_i);
  WavReader wav_reader(input_wav_file);
  if (wav_reader.num_channels() != 1) {
    RTC_LOG(LS_ERROR) << "Only mono wav files are supported";
    return 1;
  }
  if (wav_reader.sample_rate() % 100 != 0) {
    RTC_LOG(LS_ERROR) << "The sample rate must allow 10 ms frames.";
    return 1;
  }
  RTC_LOG(LS_INFO) << "Input sample rate: " << wav_reader.sample_rate();
  // Init output files. Fail early instead of writing through a null FILE*.
  const std::string output_vad_probs_file = absl::GetFlag(FLAGS_o);
  FILE* vad_probs_file = fopen(output_vad_probs_file.c_str(), "wb");
  if (!vad_probs_file) {
    RTC_LOG(LS_ERROR) << "Cannot open the VAD probabilities output file "
                      << output_vad_probs_file;
    return 1;
  }
  FILE* features_file = nullptr;
  const std::string output_feature_file = absl::GetFlag(FLAGS_f);
  if (!output_feature_file.empty()) {
    features_file = fopen(output_feature_file.c_str(), "wb");
    if (!features_file) {
      RTC_LOG(LS_ERROR) << "Cannot open the features output file "
                        << output_feature_file;
      fclose(vad_probs_file);
      return 1;
    }
  }
  // Initialize.
  const size_t frame_size_10ms =
      rtc::CheckedDivExact(wav_reader.sample_rate(), 100);
  std::vector<float> samples_10ms;
  samples_10ms.resize(frame_size_10ms);
  std::array<float, kFrameSize10ms24kHz> samples_10ms_24kHz;
  PushSincResampler resampler(frame_size_10ms, kFrameSize10ms24kHz);
  FeaturesExtractor features_extractor;
  std::array<float, kFeatureVectorSize> feature_vector;
  RnnBasedVad rnn_vad;
  // Compute VAD probabilities.
  while (true) {
    // Read frame at the input sample rate.
    const auto read_samples =
        wav_reader.ReadSamples(frame_size_10ms, samples_10ms.data());
    if (read_samples < frame_size_10ms) {
      break;  // EOF.
    }
    // Resample input.
    resampler.Resample(samples_10ms.data(), samples_10ms.size(),
                       samples_10ms_24kHz.data(), samples_10ms_24kHz.size());
    // Extract features and feed the RNN.
    bool is_silence = features_extractor.CheckSilenceComputeFeatures(
        samples_10ms_24kHz, feature_vector);
    float vad_probability =
        rnn_vad.ComputeVadProbability(feature_vector, is_silence);
    // Write voice probability.
    RTC_DCHECK_GE(vad_probability, 0.f);
    RTC_DCHECK_GE(1.f, vad_probability);
    fwrite(&vad_probability, sizeof(float), 1, vad_probs_file);
    // Write features.
    if (features_file) {
      const float float_is_silence = is_silence ? 1.f : 0.f;
      fwrite(&float_is_silence, sizeof(float), 1, features_file);
      if (is_silence) {
        // Do not write uninitialized values.
        feature_vector.fill(0.f);
      }
      fwrite(feature_vector.data(), sizeof(float), kFeatureVectorSize,
             features_file);
    }
  }
  // Close output file(s).
  fclose(vad_probs_file);
  RTC_LOG(LS_INFO) << "VAD probabilities written to " << output_vad_probs_file;
  if (features_file) {
    fclose(features_file);
    RTC_LOG(LS_INFO) << "features written to " << output_feature_file;
  }
  return 0;
}
} // namespace test
} // namespace rnn_vad
} // namespace webrtc
// Command line entry point; delegates to the namespaced implementation.
int main(int argc, char* argv[]) {
  return webrtc::rnn_vad::test::main(argc, argv);
}

View File

@ -0,0 +1,79 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
#include <algorithm>
#include <cstring>
#include <type_traits>
#include <vector>
#include "api/array_view.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
// Linear buffer implementation to (i) push fixed size chunks of sequential data
// and (ii) view contiguous parts of the buffer. The buffer and the pushed
// chunks have size S and N respectively. For instance, when S = 2N the first
// half of the sequence buffer is replaced with its second half, and the new N
// values are written at the end of the buffer.
// The class also provides a view on the most recent M values, where 0 < M <= S
// and by default M = N.
template <typename T, size_t S, size_t N, size_t M = N>
class SequenceBuffer {
  static_assert(N <= S,
                "The new chunk size cannot be larger than the sequence buffer "
                "size.");
  static_assert(std::is_arithmetic<T>::value,
                "Integral or floating point required.");
 public:
  SequenceBuffer() : buffer_(S) {
    RTC_DCHECK_EQ(S, buffer_.size());
    Reset();
  }
  SequenceBuffer(const SequenceBuffer&) = delete;
  SequenceBuffer& operator=(const SequenceBuffer&) = delete;
  ~SequenceBuffer() = default;
  size_t size() const { return S; }
  size_t chunks_size() const { return N; }
  // Sets the sequence buffer values to zero.
  void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); }
  // Returns a view on the whole buffer.
  rtc::ArrayView<const T, S> GetBufferView() const {
    return {buffer_.data(), S};
  }
  // Returns a view on the M most recent values of the buffer.
  rtc::ArrayView<const T, M> GetMostRecentValuesView() const {
    static_assert(M <= S,
                  "The number of most recent values cannot be larger than the "
                  "sequence buffer size.");
    return {buffer_.data() + S - M, M};
  }
  // Shifts left the buffer by N items and adds the new N items at the end.
  void Push(rtc::ArrayView<const T, N> new_values) {
    if (S > N) {
      // Drop the N oldest items by shifting the most recent S - N items to
      // the front (forward copy, destination precedes source).
      std::copy(buffer_.begin() + N, buffer_.end(), buffer_.begin());
    }
    // Append the new chunk at the end of the buffer.
    std::copy(new_values.begin(), new_values.end(), buffer_.end() - N);
  }
 private:
  std::vector<T> buffer_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_

View File

@ -0,0 +1,213 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// Threshold on the total band energy below which a frame is treated as
// silence.
constexpr float kSilenceThreshold = 0.04f;
// Computes the new cepstral difference stats and pushes them into the passed
// symmetric matrix buffer.
void UpdateCepstralDifferenceStats(
    rtc::ArrayView<const float, kNumBands> new_cepstral_coeffs,
    const RingBuffer<float, kNumBands, kCepstralCoeffsHistorySize>& ring_buf,
    SymmetricMatrixBuffer<float, kCepstralCoeffsHistorySize>* sym_matrix_buf) {
  RTC_DCHECK(sym_matrix_buf);
  // Squared Euclidean distance between the new coefficients and each of the
  // past frames stored in the ring buffer.
  std::array<float, kCepstralCoeffsHistorySize - 1> distances;
  for (size_t i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
    auto old_cepstral_coeffs = ring_buf.GetArrayView(i + 1);
    float squared_dist = 0.f;
    for (size_t k = 0; k < kNumBands; ++k) {
      const float diff = new_cepstral_coeffs[k] - old_cepstral_coeffs[k];
      squared_dist += diff * diff;
    }
    distances[i] = squared_dist;
  }
  // Push the new spectral distance stats into the symmetric matrix buffer.
  sym_matrix_buf->Push(distances);
}
// Computes the first half of the Vorbis window (the window is symmetric, so
// only kFrameSize20ms24kHz / 2 coefficients are needed), scaled by |scaling|.
std::array<float, kFrameSize20ms24kHz / 2> ComputeScaledHalfVorbisWindow(
    float scaling = 1.f) {
  constexpr size_t kHalfSize = kFrameSize20ms24kHz / 2;
  std::array<float, kHalfSize> half_window{};
  for (size_t i = 0; i < kHalfSize; ++i) {
    // Vorbis window: sin(pi/2 * sin^2(pi/2 * (i + 0.5) / kHalfSize)).
    const auto s = std::sin(0.5 * kPi * (i + 0.5) / kHalfSize);
    half_window[i] = scaling * std::sin(0.5 * kPi * s * s);
  }
  return half_window;
}
// Computes the forward FFT on a 20 ms frame to which a given window function is
// applied. The Fourier coefficient corresponding to the Nyquist frequency is
// set to zero (it is never used and this allows to simplify the code).
void ComputeWindowedForwardFft(
    rtc::ArrayView<const float, kFrameSize20ms24kHz> frame,
    const std::array<float, kFrameSize20ms24kHz / 2>& half_window,
    Pffft::FloatBuffer* fft_input_buffer,
    Pffft::FloatBuffer* fft_output_buffer,
    Pffft* fft) {
  RTC_DCHECK_EQ(frame.size(), 2 * half_window.size());
  // Apply windowing. The window is symmetric, so |half_window[i]| multiplies
  // both the i-th sample from the front and the i-th sample from the back.
  auto in = fft_input_buffer->GetView();
  for (size_t i = 0, j = kFrameSize20ms24kHz - 1; i < half_window.size();
       ++i, --j) {
    in[i] = frame[i] * half_window[i];
    in[j] = frame[j] * half_window[i];
  }
  fft->ForwardTransform(*fft_input_buffer, fft_output_buffer, /*ordered=*/true);
  // Set the Nyquist frequency coefficient to zero. In the ordered real-FFT
  // layout it is presumably packed at index 1 - confirm against the Pffft
  // wrapper documentation.
  auto out = fft_output_buffer->GetView();
  out[1] = 0.f;
}
} // namespace
// Ctor. Builds the half Vorbis window scaled by the inverse frame size, the
// real FFT with its work buffers and the DCT table.
SpectralFeaturesExtractor::SpectralFeaturesExtractor()
    : half_window_(ComputeScaledHalfVorbisWindow(
          1.f / static_cast<float>(kFrameSize20ms24kHz))),
      fft_(kFrameSize20ms24kHz, Pffft::FftType::kReal),
      fft_buffer_(fft_.CreateBuffer()),
      reference_frame_fft_(fft_.CreateBuffer()),
      lagged_frame_fft_(fft_.CreateBuffer()),
      dct_table_(ComputeDctTable()) {}
SpectralFeaturesExtractor::~SpectralFeaturesExtractor() = default;
// Clears the cepstral coefficients history and the pairwise cepstral
// difference stats.
void SpectralFeaturesExtractor::Reset() {
  cepstral_coeffs_ring_buf_.Reset();
  cepstral_diffs_buf_.Reset();
}
// Analyzes |reference_frame| and |lagged_frame|; returns true (early) if the
// reference frame is classified as silence, otherwise computes the spectral
// features, writes them into the output views and returns false.
bool SpectralFeaturesExtractor::CheckSilenceComputeFeatures(
    rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame,
    rtc::ArrayView<const float, kFrameSize20ms24kHz> lagged_frame,
    rtc::ArrayView<float, kNumBands - kNumLowerBands> higher_bands_cepstrum,
    rtc::ArrayView<float, kNumLowerBands> average,
    rtc::ArrayView<float, kNumLowerBands> first_derivative,
    rtc::ArrayView<float, kNumLowerBands> second_derivative,
    rtc::ArrayView<float, kNumLowerBands> bands_cross_corr,
    float* variability) {
  // Compute the Opus band energies for the reference frame.
  ComputeWindowedForwardFft(reference_frame, half_window_, fft_buffer_.get(),
                            reference_frame_fft_.get(), &fft_);
  spectral_correlator_.ComputeAutoCorrelation(
      reference_frame_fft_->GetConstView(), reference_frame_bands_energy_);
  // Check if the reference frame has silence.
  const float tot_energy =
      std::accumulate(reference_frame_bands_energy_.begin(),
                      reference_frame_bands_energy_.end(), 0.f);
  if (tot_energy < kSilenceThreshold) {
    // Silence: skip the feature computation; the output views are left
    // untouched.
    return true;
  }
  // Compute the Opus band energies for the lagged frame.
  ComputeWindowedForwardFft(lagged_frame, half_window_, fft_buffer_.get(),
                            lagged_frame_fft_.get(), &fft_);
  spectral_correlator_.ComputeAutoCorrelation(lagged_frame_fft_->GetConstView(),
                                              lagged_frame_bands_energy_);
  // Log of the band energies for the reference frame.
  std::array<float, kNumBands> log_bands_energy;
  ComputeSmoothedLogMagnitudeSpectrum(reference_frame_bands_energy_,
                                      log_bands_energy);
  // Reference frame cepstrum.
  std::array<float, kNumBands> cepstrum;
  ComputeDct(log_bands_energy, dct_table_, cepstrum);
  // Ad-hoc correction terms for the first two cepstral coefficients.
  cepstrum[0] -= 12.f;
  cepstrum[1] -= 4.f;
  // Update the ring buffer and the cepstral difference stats.
  cepstral_coeffs_ring_buf_.Push(cepstrum);
  UpdateCepstralDifferenceStats(cepstrum, cepstral_coeffs_ring_buf_,
                                &cepstral_diffs_buf_);
  // Write the higher bands cepstral coefficients.
  RTC_DCHECK_EQ(cepstrum.size() - kNumLowerBands, higher_bands_cepstrum.size());
  std::copy(cepstrum.begin() + kNumLowerBands, cepstrum.end(),
            higher_bands_cepstrum.begin());
  // Compute and write remaining features.
  ComputeAvgAndDerivatives(average, first_derivative, second_derivative);
  ComputeNormalizedCepstralCorrelation(bands_cross_corr);
  RTC_DCHECK(variability);
  *variability = ComputeVariability();
  return false;
}
// Computes, for each lower band, the average and the first/second discrete
// derivatives of the last three cepstral coefficient frames.
void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
    rtc::ArrayView<float, kNumLowerBands> average,
    rtc::ArrayView<float, kNumLowerBands> first_derivative,
    rtc::ArrayView<float, kNumLowerBands> second_derivative) const {
  // Cepstra for the current frame (c0) and the two preceding ones (c1, c2).
  const auto c0 = cepstral_coeffs_ring_buf_.GetArrayView(0);
  const auto c1 = cepstral_coeffs_ring_buf_.GetArrayView(1);
  const auto c2 = cepstral_coeffs_ring_buf_.GetArrayView(2);
  RTC_DCHECK_EQ(average.size(), first_derivative.size());
  RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size());
  RTC_DCHECK_LE(average.size(), c0.size());
  for (size_t i = 0; i < average.size(); ++i) {
    // Average, kernel: [1, 1, 1].
    average[i] = c0[i] + c1[i] + c2[i];
    // First derivative, kernel: [1, 0, -1].
    first_derivative[i] = c0[i] - c2[i];
    // Second derivative, Laplacian kernel: [1, -2, 1].
    second_derivative[i] = c0[i] - 2 * c1[i] + c2[i];
  }
}
// Computes the cross-band correlation between the reference and lagged frames,
// normalizes it by the band energies and writes its DCT (cepstrum) into
// |bands_cross_corr|.
void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
    rtc::ArrayView<float, kNumLowerBands> bands_cross_corr) {
  spectral_correlator_.ComputeCrossCorrelation(
      reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(),
      bands_cross_corr_);
  // Normalize; the 0.001 additive term avoids division by zero.
  for (size_t i = 0; i < bands_cross_corr_.size(); ++i) {
    const float energy_product =
        reference_frame_bands_energy_[i] * lagged_frame_bands_energy_[i];
    bands_cross_corr_[i] /= std::sqrt(0.001f + energy_product);
  }
  // Cepstrum.
  ComputeDct(bands_cross_corr_, dct_table_, bands_cross_corr);
  // Ad-hoc correction terms for the first two cepstral coefficients.
  bands_cross_corr[0] -= 1.3f;
  bands_cross_corr[1] -= 0.9f;
}
// Computes the cepstral variability score: for each stored frame, the minimum
// cepstral distance to any other stored frame, summed and normalized.
float SpectralFeaturesExtractor::ComputeVariability() const {
  float variability = 0.f;
  for (size_t delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
    float min_dist = std::numeric_limits<float>::max();
    for (size_t delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
      if (delay1 == delay2) {
        continue;  // The distance of a frame from itself would be 0.
      }
      const float dist = cepstral_diffs_buf_.GetValue(delay1, delay2);
      if (dist < min_dist) {
        min_dist = dist;
      }
    }
    variability += min_dist;
  }
  // Normalize (based on training set stats).
  // TODO(bugs.webrtc.org/10480): Isolate normalization from feature extraction.
  return variability / kCepstralCoeffsHistorySize - 2.1f;
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,79 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
#include <array>
#include <cstddef>
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h"
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
#include "modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h"
#include "modules/audio_processing/utility/pffft_wrapper.h"
namespace webrtc {
namespace rnn_vad {
// Class to compute spectral features (Opus-band energies, cepstrum, cepstral
// derivatives, cross-correlation and variability) for the RNN VAD.
class SpectralFeaturesExtractor {
 public:
  SpectralFeaturesExtractor();
  SpectralFeaturesExtractor(const SpectralFeaturesExtractor&) = delete;
  SpectralFeaturesExtractor& operator=(const SpectralFeaturesExtractor&) =
      delete;
  ~SpectralFeaturesExtractor();
  // Resets the internal state of the feature extractor.
  void Reset();
  // Analyzes a pair of reference and lagged frames from the pitch buffer,
  // detects silence and computes features. If silence is detected, the output
  // is neither computed nor written.
  // Returns true when silence is detected.
  bool CheckSilenceComputeFeatures(
      rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame,
      rtc::ArrayView<const float, kFrameSize20ms24kHz> lagged_frame,
      rtc::ArrayView<float, kNumBands - kNumLowerBands> higher_bands_cepstrum,
      rtc::ArrayView<float, kNumLowerBands> average,
      rtc::ArrayView<float, kNumLowerBands> first_derivative,
      rtc::ArrayView<float, kNumLowerBands> second_derivative,
      rtc::ArrayView<float, kNumLowerBands> bands_cross_corr,
      float* variability);
 private:
  void ComputeAvgAndDerivatives(
      rtc::ArrayView<float, kNumLowerBands> average,
      rtc::ArrayView<float, kNumLowerBands> first_derivative,
      rtc::ArrayView<float, kNumLowerBands> second_derivative) const;
  void ComputeNormalizedCepstralCorrelation(
      rtc::ArrayView<float, kNumLowerBands> bands_cross_corr);
  float ComputeVariability() const;
  // Analysis window (first half; the window is symmetric).
  const std::array<float, kFrameSize20ms24kHz / 2> half_window_;
  Pffft fft_;
  // Scratch buffer for the windowed input of the FFT.
  std::unique_ptr<Pffft::FloatBuffer> fft_buffer_;
  // Spectra of the most recent reference and lagged frames.
  std::unique_ptr<Pffft::FloatBuffer> reference_frame_fft_;
  std::unique_ptr<Pffft::FloatBuffer> lagged_frame_fft_;
  SpectralCorrelator spectral_correlator_;
  // Band-wise energies and cross-correlations in the Opus scale at 24 kHz.
  std::array<float, kOpusBands24kHz> reference_frame_bands_energy_;
  std::array<float, kOpusBands24kHz> lagged_frame_bands_energy_;
  std::array<float, kOpusBands24kHz> bands_cross_corr_;
  // Pre-computed DCT table used for all cepstral transforms.
  const std::array<float, kNumBands * kNumBands> dct_table_;
  // History of cepstral coefficients and of their pairwise distances, used
  // to compute averages, derivatives and the variability score.
  RingBuffer<float, kNumBands, kCepstralCoeffsHistorySize>
      cepstral_coeffs_ring_buf_;
  SymmetricMatrixBuffer<float, kCepstralCoeffsHistorySize> cepstral_diffs_buf_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_

View File

@ -0,0 +1,187 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// Weights for each FFT coefficient for each Opus band (Nyquist frequency
// excluded). The size of each band is specified in
// |kOpusScaleNumBins24kHz20ms|.
constexpr std::array<float, kFrameSize20ms24kHz / 2> kOpusBandWeights24kHz20ms =
{{
0.f, 0.25f, 0.5f, 0.75f, // Band 0
0.f, 0.25f, 0.5f, 0.75f, // Band 1
0.f, 0.25f, 0.5f, 0.75f, // Band 2
0.f, 0.25f, 0.5f, 0.75f, // Band 3
0.f, 0.25f, 0.5f, 0.75f, // Band 4
0.f, 0.25f, 0.5f, 0.75f, // Band 5
0.f, 0.25f, 0.5f, 0.75f, // Band 6
0.f, 0.25f, 0.5f, 0.75f, // Band 7
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
0.625f, 0.75f, 0.875f, // Band 8
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
0.625f, 0.75f, 0.875f, // Band 9
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
0.625f, 0.75f, 0.875f, // Band 10
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
0.625f, 0.75f, 0.875f, // Band 11
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
0.9375f, // Band 12
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
0.9375f, // Band 13
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
0.9375f, // Band 14
0.f, 0.0416667f, 0.0833333f, 0.125f, 0.166667f,
0.208333f, 0.25f, 0.291667f, 0.333333f, 0.375f,
0.416667f, 0.458333f, 0.5f, 0.541667f, 0.583333f,
0.625f, 0.666667f, 0.708333f, 0.75f, 0.791667f,
0.833333f, 0.875f, 0.916667f, 0.958333f, // Band 15
0.f, 0.0416667f, 0.0833333f, 0.125f, 0.166667f,
0.208333f, 0.25f, 0.291667f, 0.333333f, 0.375f,
0.416667f, 0.458333f, 0.5f, 0.541667f, 0.583333f,
0.625f, 0.666667f, 0.708333f, 0.75f, 0.791667f,
0.833333f, 0.875f, 0.916667f, 0.958333f, // Band 16
0.f, 0.03125f, 0.0625f, 0.09375f, 0.125f,
0.15625f, 0.1875f, 0.21875f, 0.25f, 0.28125f,
0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f,
0.46875f, 0.5f, 0.53125f, 0.5625f, 0.59375f,
0.625f, 0.65625f, 0.6875f, 0.71875f, 0.75f,
0.78125f, 0.8125f, 0.84375f, 0.875f, 0.90625f,
0.9375f, 0.96875f, // Band 17
0.f, 0.0208333f, 0.0416667f, 0.0625f, 0.0833333f,
0.104167f, 0.125f, 0.145833f, 0.166667f, 0.1875f,
0.208333f, 0.229167f, 0.25f, 0.270833f, 0.291667f,
0.3125f, 0.333333f, 0.354167f, 0.375f, 0.395833f,
0.416667f, 0.4375f, 0.458333f, 0.479167f, 0.5f,
0.520833f, 0.541667f, 0.5625f, 0.583333f, 0.604167f,
0.625f, 0.645833f, 0.666667f, 0.6875f, 0.708333f,
0.729167f, 0.75f, 0.770833f, 0.791667f, 0.8125f,
0.833333f, 0.854167f, 0.875f, 0.895833f, 0.916667f,
0.9375f, 0.958333f, 0.979167f // Band 18
}};
} // namespace
// Copies the compile-time per-coefficient weight table into a vector member,
// so the header does not need to expose the table.
SpectralCorrelator::SpectralCorrelator()
    : weights_(kOpusBandWeights24kHz20ms.begin(),
               kOpusBandWeights24kHz20ms.end()) {}
SpectralCorrelator::~SpectralCorrelator() = default;
// Computes the band-wise auto-correlation of |x| as the cross-correlation of
// |x| with itself.
void SpectralCorrelator::ComputeAutoCorrelation(
    rtc::ArrayView<const float> x,
    rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const {
  ComputeCrossCorrelation(x, x, auto_corr);
}
// Computes band-wise cross-correlations between the interleaved real-complex
// FFT coefficients |x| and |y|. Each Fourier coefficient contributes to two
// adjacent bands via the triangular filters: a fraction |weights_[k]| goes to
// the upper band and the remainder to the lower band.
void SpectralCorrelator::ComputeCrossCorrelation(
    rtc::ArrayView<const float> x,
    rtc::ArrayView<const float> y,
    rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const {
  RTC_DCHECK_EQ(x.size(), kFrameSize20ms24kHz);
  RTC_DCHECK_EQ(x.size(), y.size());
  RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
  RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
  constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
  size_t k = 0; // Next Fourier coefficient index.
  cross_corr[0] = 0.f;
  for (size_t i = 0; i < kOpusBands24kHz - 1; ++i) {
    cross_corr[i + 1] = 0.f;
    for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) { // Band size.
      // Product of the k-th complex coefficients (real*real + imag*imag).
      const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1];
      // Split between the current band and the next one (triangular filter).
      const float tmp = weights_[k] * v;
      cross_corr[i] += v - tmp;
      cross_corr[i + 1] += tmp;
      k++;
    }
  }
  cross_corr[0] *= 2.f; // The first band only gets half contribution.
  RTC_DCHECK_EQ(k, kFrameSize20ms24kHz / 2); // Nyquist coefficient never used.
}
// Given a vector of Opus-band energies, computes the log magnitude spectrum
// smoothed both w.r.t. the running maximum and a slowly decaying follower.
// Bands beyond |bands_energy.size()| are treated as having zero energy.
void ComputeSmoothedLogMagnitudeSpectrum(
    rtc::ArrayView<const float> bands_energy,
    rtc::ArrayView<float, kNumBands> log_bands_energy) {
  RTC_DCHECK_LE(bands_energy.size(), kNumBands);
  constexpr float kOneByHundred = 1e-2f;
  constexpr float kLogOneByHundred = -2.f;
  // Smoothing state: |log_max| is the running maximum and |follow| limits how
  // fast the output may drop from one band to the next.
  float log_max = kLogOneByHundred;
  float follow = kLogOneByHundred;
  for (size_t i = 0; i < kNumBands; ++i) {
    // Raw value: log energy for defined bands, the zero-energy floor after.
    float value = i < bands_energy.size()
                      ? std::log10(kOneByHundred + bands_energy[i])
                      : kLogOneByHundred;
    // Clamp from below against both the running max and the follower.
    value = std::max(log_max - 7.f, std::max(follow - 1.5f, value));
    log_max = std::max(log_max, value);
    follow = std::max(follow - 1.5f, value);
    log_bands_energy[i] = value;
  }
}
// Builds the kNumBands x kNumBands DCT-II basis table. Row i holds the
// cosines for input index i; the first column is scaled by sqrt(1/2) for
// orthogonality.
std::array<float, kNumBands * kNumBands> ComputeDctTable() {
  std::array<float, kNumBands * kNumBands> dct_table;
  const double first_column_scale = std::sqrt(0.5);
  for (size_t i = 0; i < kNumBands; ++i) {
    for (size_t j = 0; j < kNumBands; ++j) {
      dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
    }
    dct_table[i * kNumBands] *= first_column_scale;
  }
  return dct_table;
}
// Computes the DCT of |in| using the pre-computed |dct_table|. In-place
// computation is not supported; |out| may be shorter than |in| to obtain only
// the first DCT coefficients.
void ComputeDct(rtc::ArrayView<const float> in,
                rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
                rtc::ArrayView<float> out) {
  // DCT scaling factor - i.e., sqrt(2 / kNumBands).
  constexpr float kDctScalingFactor = 0.301511345f;
  constexpr float kDctScalingFactorError =
      kDctScalingFactor * kDctScalingFactor -
      2.f / static_cast<float>(kNumBands);
  static_assert(
      (kDctScalingFactorError >= 0.f && kDctScalingFactorError < 1e-1f) ||
          (kDctScalingFactorError < 0.f && kDctScalingFactorError > -1e-1f),
      "kNumBands changed and kDctScalingFactor has not been updated.");
  RTC_DCHECK_NE(in.data(), out.data()) << "In-place DCT is not supported.";
  RTC_DCHECK_LE(in.size(), kNumBands);
  RTC_DCHECK_LE(1, out.size());
  RTC_DCHECK_LE(out.size(), in.size());
  for (size_t i = 0; i < out.size(); ++i) {
    // Dot product of |in| with the i-th column of the DCT table.
    float accumulator = 0.f;
    for (size_t j = 0; j < in.size(); ++j) {
      accumulator += in[j] * dct_table[j * kNumBands + i];
    }
    // TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table.
    out[i] = accumulator * kDctScalingFactor;
  }
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,100 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
#include <stddef.h>
#include <array>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
namespace webrtc {
namespace rnn_vad {
// At a sample rate of 24 kHz, the last 3 Opus bands are beyond the Nyquist
// frequency. However, band #19 gets the contributions from band #18 because
// of the symmetric triangular filter with peak response at 12 kHz.
constexpr size_t kOpusBands24kHz = 20;
static_assert(kOpusBands24kHz < kNumBands,
"The number of bands at 24 kHz must be less than those defined "
"in the Opus scale at 48 kHz.");
// Number of FFT frequency bins covered by each band in the Opus scale at a
// sample rate of 24 kHz for 20 ms frames.
// Declared here for unit testing.
constexpr std::array<int, kOpusBands24kHz - 1> GetOpusScaleNumBins24kHz20ms() {
return {4, 4, 4, 4, 4, 4, 4, 4, 8, 8, 8, 8, 16, 16, 16, 24, 24, 32, 48};
}
// TODO(bugs.webrtc.org/10480): Move to a separate file.
// Class to compute band-wise spectral features in the Opus perceptual scale
// for 20 ms frames sampled at 24 kHz. The analysis methods apply triangular
// filters with peak response at the each band boundary.
class SpectralCorrelator {
public:
// Ctor.
SpectralCorrelator();
SpectralCorrelator(const SpectralCorrelator&) = delete;
SpectralCorrelator& operator=(const SpectralCorrelator&) = delete;
~SpectralCorrelator();
// Computes the band-wise spectral auto-correlations.
// |x| must:
// - have size equal to |kFrameSize20ms24kHz|;
// - be encoded as vectors of interleaved real-complex FFT coefficients
// where x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
void ComputeAutoCorrelation(
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const;
// Computes the band-wise spectral cross-correlations.
// |x| and |y| must:
// - have size equal to |kFrameSize20ms24kHz|;
// - be encoded as vectors of interleaved real-complex FFT coefficients where
// x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
void ComputeCrossCorrelation(
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const;
private:
const std::vector<float> weights_; // Weights for each Fourier coefficient.
};
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
// spectral_features.cc. Given a vector of Opus-bands energy coefficients,
// computes the log magnitude spectrum applying smoothing both over time and
// over frequency. Declared here for unit testing.
void ComputeSmoothedLogMagnitudeSpectrum(
rtc::ArrayView<const float> bands_energy,
rtc::ArrayView<float, kNumBands> log_bands_energy);
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
// spectral_features.cc. Creates a DCT table for arrays having size equal to
// |kNumBands|. Declared here for unit testing.
std::array<float, kNumBands * kNumBands> ComputeDctTable();
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
// spectral_features.cc. Computes DCT for |in| given a pre-computed DCT table.
// In-place computation is not allowed and |out| can be smaller than |in| in
// order to only compute the first DCT coefficients. Declared here for unit
// testing.
void ComputeDct(rtc::ArrayView<const float> in,
rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
rtc::ArrayView<float> out);
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_

View File

@ -0,0 +1,94 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
#include <algorithm>
#include <array>
#include <cstring>
#include <utility>
#include "api/array_view.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
// Data structure to buffer the results of pair-wise comparisons between items
// stored in a ring buffer. Every time that the oldest item is replaced in the
// ring buffer, the new one is compared to the remaining items in the ring
// buffer. The results of such comparisons need to be buffered and automatically
// removed when one of the two corresponding items that have been compared is
// removed from the ring buffer. It is assumed that the comparison is symmetric
// and that comparing an item with itself is not needed.
template <typename T, size_t S>
class SymmetricMatrixBuffer {
  static_assert(S > 2, "");
 public:
  SymmetricMatrixBuffer() = default;
  SymmetricMatrixBuffer(const SymmetricMatrixBuffer&) = delete;
  SymmetricMatrixBuffer& operator=(const SymmetricMatrixBuffer&) = delete;
  ~SymmetricMatrixBuffer() = default;
  // Sets the buffer values to zero.
  void Reset() {
    static_assert(std::is_arithmetic<T>::value,
                  "Integral or floating point required.");
    buf_.fill(0);
  }
  // Pushes the results from the comparison between the most recent item and
  // those that are still in the ring buffer. The first element in |values| must
  // correspond to the comparison between the most recent item and the second
  // most recent one in the ring buffer, whereas the last element in |values|
  // must correspond to the comparison between the most recent item and the
  // oldest one in the ring buffer.
  void Push(rtc::ArrayView<T, S - 1> values) {
    // Move the lower-right sub-matrix of size (S-2) x (S-2) one row up and one
    // column left.
    std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T));
    // Copy new values in the last column in the right order.
    for (size_t i = 0; i < values.size(); ++i) {
      // Row (S-2-i) of the encoded triangular matrix, last column; see the
      // index formula in GetValue() with col == S-1.
      const size_t index = (S - 1 - i) * (S - 1) - 1;
      RTC_DCHECK_LE(static_cast<size_t>(0), index);
      RTC_DCHECK_LT(index, buf_.size());
      buf_[index] = values[i];
    }
  }
  // Reads the value that corresponds to comparison of two items in the ring
  // buffer having delay |delay1| and |delay2|. The two arguments must not be
  // equal and both must be in {0, ..., S - 1}.
  T GetValue(size_t delay1, size_t delay2) const {
    // Map delays to matrix coordinates (delay 0 -> most recent item).
    int row = S - 1 - static_cast<int>(delay1);
    int col = S - 1 - static_cast<int>(delay2);
    RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed.";
    if (row > col)
      std::swap(row, col);  // Swap to access the upper-right triangular part.
    RTC_DCHECK_LE(0, row);
    RTC_DCHECK_LT(row, S - 1) << "Not enforcing row < col and row != col.";
    RTC_DCHECK_LE(1, col) << "Not enforcing row < col and row != col.";
    RTC_DCHECK_LT(col, S);
    // Linearize: the diagonal is excluded, hence the (col - 1) shift.
    const int index = row * (S - 1) + (col - 1);
    RTC_DCHECK_LE(0, index);
    RTC_DCHECK_LT(index, buf_.size());
    return buf_[index];
  }
 private:
  // Encode an upper-right triangular matrix (excluding its diagonal) using a
  // square matrix. This allows to move the data in Push() with one single
  // operation.
  std::array<T, (S - 1) * (S - 1)> buf_{};
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_

View File

@ -0,0 +1,129 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include <memory>
#include "rtc_base/checks.h"
#include "rtc_base/system/arch.h"
#include "system_wrappers/include/cpu_features_wrapper.h"
#include "test/gtest.h"
#include "test/testsupport/file_utils.h"
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
using ReaderPairType =
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>;
} // namespace
using webrtc::test::ResourcePath;
// Checks that |expected| and |computed| have the same size and that every
// pair of elements compares equal with EXPECT_FLOAT_EQ (ULP-based).
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
                           rtc::ArrayView<const float> computed) {
  ASSERT_EQ(expected.size(), computed.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    SCOPED_TRACE(i);  // Report the failing index in the gtest output.
    EXPECT_FLOAT_EQ(expected[i], computed[i]);
  }
}
// Checks that |expected| and |computed| have the same size and that every
// pair of elements differs by at most |tolerance| (absolute error).
void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
                        rtc::ArrayView<const float> computed,
                        float tolerance) {
  ASSERT_EQ(expected.size(), computed.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    SCOPED_TRACE(i);  // Report the failing index in the gtest output.
    EXPECT_NEAR(expected[i], computed[i], tolerance);
  }
}
// Creates a reader for the PCM samples resource file, casting int16_t samples
// to float and reading chunks of |frame_length| samples. Returns the reader
// and the number of complete frames available.
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
CreatePcmSamplesReader(const size_t frame_length) {
  auto ptr = std::make_unique<BinaryFileReader<int16_t, float>>(
      test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"),
      frame_length);
  // The last incomplete frame is ignored.
  // Query the length before |ptr| is moved into the returned pair: relying on
  // braced-init-list evaluation order is fragile and trips use-after-move
  // linters.
  const size_t num_frames = ptr->data_length() / frame_length;
  return {std::move(ptr), num_frames};
}
// Creates a reader for the pitch buffer content at 24 kHz. Returns the reader
// and the number of chunks (rows of |cols| floats) in the resource file.
ReaderPairType CreatePitchBuffer24kHzReader() {
  constexpr size_t cols = 864;
  auto ptr = std::make_unique<BinaryFileReader<float>>(
      ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
  // Query the length before |ptr| is moved into the returned pair (avoids a
  // use-after-move hazard pattern).
  const size_t num_chunks = rtc::CheckedDivExact(ptr->data_length(), cols);
  return {std::move(ptr), num_chunks};
}
// Creates a reader for the LP residual coefficients and the pitch period and
// gain values. Each chunk holds 2 values (period, gain) followed by the LP
// residual coefficients.
ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
  constexpr size_t num_lp_residual_coeffs = 864;
  auto ptr = std::make_unique<BinaryFileReader<float>>(
      ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
      num_lp_residual_coeffs);
  // Query the length before |ptr| is moved into the returned pair (avoids a
  // use-after-move hazard pattern).
  const size_t num_chunks =
      rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs);
  return {std::move(ptr), num_chunks};
}
// Creates a reader for the VAD probabilities resource file; the chunk count
// equals the number of stored float values.
ReaderPairType CreateVadProbsReader() {
  auto ptr = std::make_unique<BinaryFileReader<float>>(
      test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
  // Query the length before |ptr| is moved into the returned pair (avoids a
  // use-after-move hazard pattern).
  const size_t num_probs = ptr->data_length();
  return {std::move(ptr), num_probs};
}
// Loads the pitch test data file into |test_data_|: the pitch buffer content
// followed by the expected square energies and auto-correlation coefficients.
// NOTE(review): assumes 1396 == kPitchTestDataSize; the accessor views below
// slice |test_data_| according to that layout.
PitchTestData::PitchTestData() {
  BinaryFileReader<float> test_data_reader(
      ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
      static_cast<size_t>(1396));
  test_data_reader.ReadChunk(test_data_);
}
PitchTestData::~PitchTestData() = default;
// Returns a view on the pitch buffer section of the test data.
rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
    const {
  return {test_data_.data(), kBufSize24kHz};
}
// Returns a view on the expected square energies, stored right after the
// pitch buffer.
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
PitchTestData::GetPitchBufSquareEnergiesView() const {
  return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
}
// Returns a view on the expected auto-correlation coefficients, stored after
// the square energies.
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
  return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
          kNumPitchBufAutoCorrCoeffs};
}
// Returns true if the given SIMD optimization is available on this platform
// (compile-time for NEON, runtime CPU detection for SSE2).
bool IsOptimizationAvailable(Optimization optimization) {
  switch (optimization) {
    case Optimization::kSse2:
#if defined(WEBRTC_ARCH_X86_FAMILY)
      return GetCPUInfo(kSSE2) != 0;
#else
      return false;
#endif
    case Optimization::kNeon:
#if defined(WEBRTC_HAS_NEON)
      return true;
#else
      return false;
#endif
    case Optimization::kNone:
      return true;
  }
  // All enumerators are handled above; this keeps compilers that cannot prove
  // the switch exhaustive (-Wreturn-type) from warning and avoids undefined
  // behavior if an out-of-range enum value is ever passed.
  return false;
}
} // namespace test
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,161 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
#include <algorithm>
#include <array>
#include <fstream>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace test {
constexpr float kFloatMin = std::numeric_limits<float>::min();
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
// that the values in the pair do not match.
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed);
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
// that their absolute error is above a given threshold.
void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed,
float tolerance);
// Reader for binary files consisting of an arbitrary long sequence of elements
// having type T. It is possible to read and cast to another type D at once.
template <typename T, typename D = T>
class BinaryFileReader {
 public:
  // Opens |file_path| in binary mode and computes the number of stored items
  // of type T. If |chunk_size| is non-zero, ReadChunk() calls are expected to
  // read exactly |chunk_size| items at a time.
  explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0)
      : is_(file_path, std::ios::binary | std::ios::ate),
        data_length_(is_.tellg() / sizeof(T)),
        chunk_size_(chunk_size) {
    RTC_CHECK(is_);
    SeekBeginning();
    buf_.resize(chunk_size_);
  }
  BinaryFileReader(const BinaryFileReader&) = delete;
  BinaryFileReader& operator=(const BinaryFileReader&) = delete;
  ~BinaryFileReader() = default;
  // Total number of items of type T stored in the file.
  size_t data_length() const { return data_length_; }
  // Reads one value of type T, casting it to D. Returns true if a full value
  // was read.
  bool ReadValue(D* dst) {
    if (std::is_same<T, D>::value) {
      is_.read(reinterpret_cast<char*>(dst), sizeof(T));
    } else {
      T v;
      is_.read(reinterpret_cast<char*>(&v), sizeof(T));
      *dst = static_cast<D>(v);
    }
    return is_.gcount() == sizeof(T);
  }
  // If |chunk_size| was specified in the ctor, it will check that the size of
  // |dst| equals |chunk_size|. Returns true if the full chunk was read. When
  // T != D, the raw values are staged in |buf_| and cast element-wise.
  bool ReadChunk(rtc::ArrayView<D> dst) {
    RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size()));
    const std::streamsize bytes_to_read = dst.size() * sizeof(T);
    if (std::is_same<T, D>::value) {
      is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
    } else {
      is_.read(reinterpret_cast<char*>(buf_.data()), bytes_to_read);
      std::transform(buf_.begin(), buf_.end(), dst.begin(),
                     [](const T& v) -> D { return static_cast<D>(v); });
    }
    return is_.gcount() == bytes_to_read;
  }
  // Skips |items| values of type T without reading them.
  void SeekForward(size_t items) { is_.seekg(items * sizeof(T), is_.cur); }
  // Rewinds to the beginning of the file.
  void SeekBeginning() { is_.seekg(0, is_.beg); }
 private:
  std::ifstream is_;
  const size_t data_length_;   // Number of items of type T in the file.
  const size_t chunk_size_;    // Expected chunk size (0 if unspecified).
  std::vector<T> buf_;         // Staging buffer used when T != D.
};
// Writer for binary files holding a flat sequence of arithmetic values.
template <typename T>
class BinaryFileWriter {
  static_assert(std::is_arithmetic<T>::value, "");

 public:
  // Opens |file_path| for binary writing.
  explicit BinaryFileWriter(const std::string& file_path)
      : os_(file_path, std::ios::binary) {}
  BinaryFileWriter(const BinaryFileWriter&) = delete;
  BinaryFileWriter& operator=(const BinaryFileWriter&) = delete;
  ~BinaryFileWriter() = default;
  // Appends the raw bytes of |value| to the file.
  void WriteChunk(rtc::ArrayView<const T> value) {
    os_.write(reinterpret_cast<const char*>(value.data()),
              value.size() * sizeof(T));
  }

 private:
  std::ofstream os_;
};
// Factories for resource file readers.
// The functions below return a pair where the first item is a reader unique
// pointer and the second the number of chunks that can be read from the file.
// Creates a reader for the PCM samples that casts from S16 to float and reads
// chunks with length |frame_length|.
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
CreatePcmSamplesReader(const size_t frame_length);
// Creates a reader for the pitch buffer content at 24 kHz.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreatePitchBuffer24kHzReader();
// Creates a reader for the the LP residual coefficients and the pitch period
// and gain values.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateLpResidualAndPitchPeriodGainReader();
// Creates a reader for the VAD probabilities.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateVadProbsReader();
constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
constexpr size_t kNumPitchBufSquareEnergies = 385;
constexpr size_t kPitchTestDataSize =
kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
// Class to retrieve a test pitch buffer content and the expected output for the
// analysis steps.
class PitchTestData {
public:
PitchTestData();
~PitchTestData();
rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() const;
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
GetPitchBufSquareEnergiesView() const;
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
GetPitchBufAutoCorrCoeffsView() const;
private:
std::array<float, kPitchTestDataSize> test_data_;
};
// Returns true if the given optimization is available.
bool IsOptimizationAvailable(Optimization optimization);
} // namespace test
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_

View File

@ -0,0 +1,121 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
constexpr float kMinLevelDbfs = -90.f;
// Min/max margins are based on speech crest-factor.
constexpr float kMinMarginDb = 12.f;
constexpr float kMaxMarginDb = 25.f;
using saturation_protector_impl::RingBuffer;
} // namespace
// Two ring buffers compare equal when they hold the same number of values and
// the values match pairwise in insertion order (storage offsets may differ).
bool RingBuffer::operator==(const RingBuffer& b) const {
  RTC_DCHECK_LE(size_, buffer_.size());
  RTC_DCHECK_LE(b.size_, b.buffer_.size());
  if (size_ != b.size_) {
    return false;
  }
  size_t index_a = static_cast<size_t>(FrontIndex());
  size_t index_b = static_cast<size_t>(b.FrontIndex());
  for (int i = 0; i < size_; ++i) {
    if (buffer_[index_a] != b.buffer_[index_b]) {
      return false;
    }
    // Advance with wrap-around in each buffer.
    index_a = (index_a + 1) % buffer_.size();
    index_b = (index_b + 1) % b.buffer_.size();
  }
  return true;
}
// Empties the buffer by resetting the write position and the logical size.
// The underlying storage is intentionally left untouched.
void RingBuffer::Reset() {
  next_ = 0;
  size_ = 0;
}
// Appends |v|, overwriting the oldest value once the buffer is full.
void RingBuffer::PushBack(float v) {
  RTC_DCHECK_GE(next_, 0);
  RTC_DCHECK_GE(size_, 0);
  RTC_DCHECK_LT(next_, buffer_.size());
  RTC_DCHECK_LE(size_, buffer_.size());
  // Write at the current position, then advance it with wrap-around.
  buffer_[next_] = v;
  ++next_;
  if (rtc::SafeEq(next_, buffer_.size())) {
    next_ = 0;
  }
  // Grow the logical size until the buffer is full.
  if (rtc::SafeLt(size_, buffer_.size())) {
    ++size_;
  }
}
// Returns the oldest value in the buffer, or an empty optional if the buffer
// holds no values.
absl::optional<float> RingBuffer::Front() const {
  if (size_ > 0) {
    RTC_DCHECK_LT(FrontIndex(), buffer_.size());
    return buffer_[FrontIndex()];
  }
  return absl::nullopt;
}
// Two states are equal when the recommended margin, the delayed-peak ring
// buffer, the current max peak and the push timer all match.
bool SaturationProtectorState::operator==(
    const SaturationProtectorState& b) const {
  return margin_db == b.margin_db && peak_delay_buffer == b.peak_delay_buffer &&
         max_peaks_dbfs == b.max_peaks_dbfs &&
         time_since_push_ms == b.time_since_push_ms;
}
// Restores the initial conditions: the given margin, an empty peak delay
// buffer, the minimum peak level and a freshly started push timer.
void ResetSaturationProtectorState(float initial_margin_db,
                                   SaturationProtectorState& state) {
  state.margin_db = initial_margin_db;
  state.peak_delay_buffer.Reset();
  state.max_peaks_dbfs = kMinLevelDbfs;
  state.time_since_push_ms = 0;
}
// Updates the saturation protector state for a frame classified as speech,
// given its peak power and the current speech level estimate. The recommended
// margin follows the (delayed) peak-to-level difference with asymmetric
// exponential smoothing (fast attack, slow decay).
void UpdateSaturationProtectorState(float speech_peak_dbfs,
                                    float speech_level_dbfs,
                                    SaturationProtectorState& state) {
  // Get the max peak over `kPeakEnveloperSuperFrameLengthMs` ms.
  state.max_peaks_dbfs = std::max(state.max_peaks_dbfs, speech_peak_dbfs);
  state.time_since_push_ms += kFrameDurationMs;
  if (rtc::SafeGt(state.time_since_push_ms, kPeakEnveloperSuperFrameLengthMs)) {
    // Push `max_peaks_dbfs` back into the ring buffer.
    state.peak_delay_buffer.PushBack(state.max_peaks_dbfs);
    // Reset.
    state.max_peaks_dbfs = kMinLevelDbfs;
    state.time_since_push_ms = 0;
  }
  // Update margin by comparing the estimated speech level and the delayed max
  // speech peak power. If the delay buffer is still empty, fall back to the
  // current super-frame max.
  // TODO(alessiob): Check with aleloi@ why we use a delay and how to tune it.
  const float delayed_peak_dbfs =
      state.peak_delay_buffer.Front().value_or(state.max_peaks_dbfs);
  const float difference_db = delayed_peak_dbfs - speech_level_dbfs;
  if (difference_db > state.margin_db) {
    // Attack.
    state.margin_db =
        state.margin_db * kSaturationProtectorAttackConstant +
        difference_db * (1.f - kSaturationProtectorAttackConstant);
  } else {
    // Decay.
    state.margin_db = state.margin_db * kSaturationProtectorDecayConstant +
                      difference_db * (1.f - kSaturationProtectorDecayConstant);
  }
  // Keep the margin within the speech crest-factor based bounds.
  state.margin_db =
      rtc::SafeClamp<float>(state.margin_db, kMinMarginDb, kMaxMarginDb);
}
} // namespace webrtc

View File

@ -0,0 +1,82 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
#include <array>
#include "absl/types/optional.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace saturation_protector_impl {
// Ring buffer which only supports (i) push back and (ii) read oldest item.
class RingBuffer {
 public:
  // Element-wise comparison of buffer contents and bookkeeping state.
  bool operator==(const RingBuffer& b) const;
  inline bool operator!=(const RingBuffer& b) const { return !(*this == b); }
  // Maximum number of values that the buffer can contain.
  int Capacity() const { return buffer_.size(); }
  // Number of values in the buffer.
  int Size() const { return size_; }
  // Empties the buffer (capacity is fixed at compile time).
  void Reset();
  // Pushes back `v`. If the buffer is full, the oldest value is replaced.
  void PushBack(float v);
  // Returns the oldest item in the buffer. Returns an empty value if the
  // buffer is empty.
  absl::optional<float> Front() const;

 private:
  // Index of the oldest element: once the buffer is full, the oldest value
  // sits at `next_` (the slot about to be overwritten); before that, it is
  // at index 0.
  inline int FrontIndex() const {
    return rtc::SafeEq(size_, buffer_.size()) ? next_ : 0;
  }
  // `buffer_` has `size_` elements (up to the size of `buffer_`) and `next_` is
  // the position where the next new value is written in `buffer_`.
  std::array<float, kPeakEnveloperBufferSize> buffer_;
  int next_ = 0;
  int size_ = 0;
};
} // namespace saturation_protector_impl
// Saturation protector state. Exposed publicly for check-pointing and restore
// ops.
struct SaturationProtectorState {
  // Element-wise equality, including the peak delay buffer contents.
  bool operator==(const SaturationProtectorState& s) const;
  inline bool operator!=(const SaturationProtectorState& s) const {
    return !(*this == s);
  }

  float margin_db;  // Recommended margin.
  // Delay line of observed speech peak levels (dBFS).
  saturation_protector_impl::RingBuffer peak_delay_buffer;
  // Maximum speech peak level (dBFS) observed since the last buffer push.
  float max_peaks_dbfs;
  int time_since_push_ms;  // Time since the last ring buffer push operation.
};
// Resets the saturation protector state, setting the recommended margin to
// `initial_margin_db` and clearing the peak delay buffer.
void ResetSaturationProtectorState(float initial_margin_db,
                                   SaturationProtectorState& state);

// Updates `state` by analyzing the estimated speech level `speech_level_dbfs`
// and the peak power `speech_peak_dbfs` for an observed frame which is
// reliably classified as "speech". `state` must not be modified without calling
// this function.
void UpdateSaturationProtectorState(float speech_peak_dbfs,
                                    float speech_level_dbfs,
                                    SaturationProtectorState& state);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_

View File

@ -0,0 +1,177 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/signal_classifier.h"
#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/down_sampler.h"
#include "modules/audio_processing/agc2/noise_spectrum_estimator.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "system_wrappers/include/cpu_features_wrapper.h"
namespace webrtc {
namespace {
// Returns true if the CPU reports SSE2 support; always false on non-x86
// targets. Used to select the optimized FFT implementation.
bool IsSse2Available() {
#if defined(WEBRTC_ARCH_X86_FAMILY)
  return GetCPUInfo(kSSE2) != 0;
#else
  return false;
#endif
}
// Subtracts the mean of `x` from every sample in place (DC removal).
// `x` must be non-empty.
void RemoveDcLevel(rtc::ArrayView<float> x) {
  RTC_DCHECK_LT(0, x.size());
  float sum = 0.f;
  for (const float sample : x) {
    sum += sample;
  }
  const float mean = sum / x.size();
  for (float& sample : x) {
    sample -= mean;
  }
}
// Computes the 65-bin power spectrum of the 128-sample frame `x` using the
// Ooura FFT. The FFT output stores the purely real DC and Nyquist components
// in slots 0 and 1, followed by interleaved (real, imaginary) pairs.
void PowerSpectrum(const OouraFft* ooura_fft,
                   rtc::ArrayView<const float> x,
                   rtc::ArrayView<float> spectrum) {
  RTC_DCHECK_EQ(65, spectrum.size());
  RTC_DCHECK_EQ(128, x.size());
  // Run the FFT in place on a local copy of the input.
  float X[128];
  std::copy(x.data(), x.data() + x.size(), X);
  ooura_fft->Fft(X);
  // DC (bin 0) and Nyquist (bin 64) components are real-only.
  spectrum[0] = X[0] * X[0];
  spectrum[64] = X[1] * X[1];
  // Remaining bins: squared magnitude of each complex pair.
  for (int k = 1; k < 64; ++k) {
    const float re = X[2 * k];
    const float im = X[2 * k + 1];
    spectrum[k] = re * re + im * im;
  }
}
// Classifies a frame as stationary or nonstationary by comparing its power
// spectrum against the noise spectrum estimate on bands 1..39.
webrtc::SignalClassifier::SignalType ClassifySignal(
    rtc::ArrayView<const float> signal_spectrum,
    rtc::ArrayView<const float> noise_spectrum,
    ApmDataDumper* data_dumper) {
  int num_stationary_bands = 0;
  int num_highly_nonstationary_bands = 0;
  // A band is "stationary" when its power is within a factor of 3 of the
  // noise estimate, and "highly nonstationary" when it exceeds 9x the noise.
  for (size_t k = 1; k < 40; k++) {
    const float signal_power = signal_spectrum[k];
    const float noise_power = noise_spectrum[k];
    const bool near_noise_level =
        signal_power < 3 * noise_power && signal_power * 3 > noise_power;
    if (near_noise_level) {
      ++num_stationary_bands;
    } else if (signal_power > 9 * noise_power) {
      ++num_highly_nonstationary_bands;
    }
  }
  data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands);
  data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1,
                       &num_highly_nonstationary_bands);
  // Derive the overall signal stationarity from the per-band counts.
  return num_stationary_bands > 15
             ? SignalClassifier::SignalType::kStationary
             : SignalClassifier::SignalType::kNonStationary;
}
} // namespace
SignalClassifier::FrameExtender::FrameExtender(size_t frame_size,
                                               size_t extended_frame_size)
    // The history buffer holds exactly the extra samples needed to extend a
    // frame of `frame_size` samples to `extended_frame_size`.
    : x_old_(extended_frame_size - frame_size, 0.f) {}

SignalClassifier::FrameExtender::~FrameExtender() = default;
// Produces `x_extended` = [stored history | x] and then saves the tail of
// `x_extended` as the history for the next call.
void SignalClassifier::FrameExtender::ExtendFrame(
    rtc::ArrayView<const float> x,
    rtc::ArrayView<float> x_extended) {
  RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size());
  const size_t history_size = x_old_.size();
  std::copy(x_old_.begin(), x_old_.end(), x_extended.data());
  std::copy(x.begin(), x.end(), x_extended.data() + history_size);
  // Keep the most recent `history_size` samples for the next frame.
  std::copy(x_extended.data() + x_extended.size() - history_size,
            x_extended.data() + x_extended.size(), x_old_.begin());
}
SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper)
    : data_dumper_(data_dumper),
      down_sampler_(data_dumper_),
      noise_spectrum_estimator_(data_dumper_),
      // Select the SSE2-optimized FFT when the CPU supports it.
      ooura_fft_(IsSse2Available()) {
  // Default to 48 kHz; Initialize() may be called again for another rate.
  Initialize(48000);
}

SignalClassifier::~SignalClassifier() {}
void SignalClassifier::Initialize(int sample_rate_hz) {
down_sampler_.Initialize(sample_rate_hz);
noise_spectrum_estimator_.Initialize();
frame_extender_.reset(new FrameExtender(80, 128));
sample_rate_hz_ = sample_rate_hz;
initialization_frames_left_ = 2;
consistent_classification_counter_ = 3;
last_signal_type_ = SignalClassifier::SignalType::kNonStationary;
}
// Analyzes a 10 ms frame (`signal` size must be sample_rate_hz / 100) and
// returns its stationarity classification. Reports kNonStationary until the
// same classification has been observed for several consecutive frames.
SignalClassifier::SignalType SignalClassifier::Analyze(
    rtc::ArrayView<const float> signal) {
  RTC_DCHECK_EQ(signal.size(), sample_rate_hz_ / 100);
  // Down-sample to 80 samples and extend to the 128 sample FFT size using
  // history from the previous frame, then remove the DC component.
  float downsampled_frame[80];
  down_sampler_.DownSample(signal, downsampled_frame);
  float extended_frame[128];
  frame_extender_->ExtendFrame(downsampled_frame, extended_frame);
  RemoveDcLevel(extended_frame);
  float signal_spectrum[65];
  PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum);
  // Classify against the current noise spectrum estimate.
  const SignalType signal_type = ClassifySignal(
      signal_spectrum, noise_spectrum_estimator_.GetNoiseSpectrum(),
      data_dumper_);
  // Update the noise estimate (flagged as initial while frames remain), then
  // count down the initialization frames.
  noise_spectrum_estimator_.Update(signal_spectrum,
                                   initialization_frames_left_ > 0);
  initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1);
  // Require consecutive identical classifications before reporting one.
  if (signal_type == last_signal_type_) {
    consistent_classification_counter_ =
        std::max(0, consistent_classification_counter_ - 1);
  } else {
    last_signal_type_ = signal_type;
    consistent_classification_counter_ = 3;
  }
  return consistent_classification_counter_ > 0
             ? SignalClassifier::SignalType::kNonStationary
             : signal_type;
}
} // namespace webrtc

View File

@ -0,0 +1,73 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "common_audio/third_party/ooura/fft_size_128/ooura_fft.h"
#include "modules/audio_processing/agc2/down_sampler.h"
#include "modules/audio_processing/agc2/noise_spectrum_estimator.h"
namespace webrtc {
class ApmDataDumper;
class AudioBuffer;
// Classifies 10 ms audio frames as stationary or nonstationary by comparing
// the frame's power spectrum against a running noise spectrum estimate.
class SignalClassifier {
 public:
  // Classification outcome for an analyzed frame.
  enum class SignalType { kNonStationary, kStationary };

  explicit SignalClassifier(ApmDataDumper* data_dumper);
  SignalClassifier() = delete;
  SignalClassifier(const SignalClassifier&) = delete;
  SignalClassifier& operator=(const SignalClassifier&) = delete;
  ~SignalClassifier();

  // (Re)initializes the classifier and its sub-components for
  // `sample_rate_hz`.
  void Initialize(int sample_rate_hz);
  // Analyzes one 10 ms frame; `signal` size must equal sample_rate_hz / 100.
  SignalType Analyze(rtc::ArrayView<const float> signal);

 private:
  // Keeps history samples so that consecutive 80 sample frames can be
  // extended to the 128 sample FFT input size.
  class FrameExtender {
   public:
    FrameExtender(size_t frame_size, size_t extended_frame_size);
    FrameExtender() = delete;
    FrameExtender(const FrameExtender&) = delete;
    FrameExtender& operator=(const FrameExtender&) = delete;
    ~FrameExtender();

    // Writes [stored history | x] into `x_extended` and updates the history.
    void ExtendFrame(rtc::ArrayView<const float> x,
                     rtc::ArrayView<float> x_extended);

   private:
    std::vector<float> x_old_;  // History carried over from the last frame.
  };

  ApmDataDumper* const data_dumper_;
  DownSampler down_sampler_;
  std::unique_ptr<FrameExtender> frame_extender_;
  NoiseSpectrumEstimator noise_spectrum_estimator_;
  int sample_rate_hz_;
  // Frames remaining during which the noise estimator is initializing.
  int initialization_frames_left_;
  // Consecutive identical classifications still required before reporting.
  int consistent_classification_counter_;
  SignalType last_signal_type_;
  const OouraFft ooura_fft_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_

View File

@ -0,0 +1,114 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/vad_with_level.h"
#include <algorithm>
#include <array>
#include <cmath>
#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector;
// Default VAD that combines a resampler and the RNN VAD.
// Computes the speech probability on the first channel.
class Vad : public VoiceActivityDetector {
public:
Vad() = default;
Vad(const Vad&) = delete;
Vad& operator=(const Vad&) = delete;
~Vad() = default;
float ComputeProbability(AudioFrameView<const float> frame) override {
// The source number of channels is 1, because we always use the 1st
// channel.
resampler_.InitializeIfNeeded(
/*sample_rate_hz=*/static_cast<int>(frame.samples_per_channel() * 100),
rnn_vad::kSampleRate24kHz,
/*num_channels=*/1);
std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
// Feed the 1st channel to the resampler.
resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
work_frame.data(), rnn_vad::kFrameSize10ms24kHz);
std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
work_frame, feature_vector);
return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
}
private:
PushResampler<float> resampler_;
rnn_vad::FeaturesExtractor features_extractor_;
rnn_vad::RnnBasedVad rnn_vad_;
};
// Returns an updated version of `p_old` by using instant decay and the given
// `attack` on a new VAD probability value `p_new`. `attack` must lie in
// (0, 1]; a value of 1 disables smoothing.
float SmoothedVadProbability(float p_old, float p_new, float attack) {
  RTC_DCHECK_GT(attack, 0.f);
  RTC_DCHECK_LE(attack, 1.f);
  // Decaying probabilities (and attack == 1) pass through unsmoothed;
  // otherwise apply a first-order attack filter.
  return (p_new < p_old || attack == 1.f)
             ? p_new
             : attack * p_new + (1.f - attack) * p_old;
}
} // namespace
// Default ctor: default RNN-based VAD with the default smoothing attack.
VadLevelAnalyzer::VadLevelAnalyzer()
    : VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack,
                       std::make_unique<Vad>()) {}

// Default RNN-based VAD with a custom smoothing attack coefficient.
VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack)
    : VadLevelAnalyzer(vad_probability_attack, std::make_unique<Vad>()) {}

// Delegated-to ctor: takes ownership of `vad`, which must not be null.
VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack,
                                   std::unique_ptr<VoiceActivityDetector> vad)
    : vad_(std::move(vad)), vad_probability_attack_(vad_probability_attack) {
  RTC_DCHECK(vad_);
}

VadLevelAnalyzer::~VadLevelAnalyzer() = default;
// Computes the smoothed speech probability plus RMS and peak levels (dBFS)
// for `frame`. Levels are measured on the first channel only.
VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
    AudioFrameView<const float> frame) {
  // Accumulate peak magnitude and energy over the first channel.
  float peak_level = 0.f;
  float energy = 0.f;
  for (const float sample : frame.channel(0)) {
    peak_level = std::max(std::fabs(sample), peak_level);
    energy += sample * sample;
  }
  // Smooth the new VAD probability with instant decay and slow attack.
  const float new_probability = vad_->ComputeProbability(frame);
  vad_probability_ = SmoothedVadProbability(
      /*p_old=*/vad_probability_, /*p_new=*/new_probability,
      vad_probability_attack_);
  return {vad_probability_,
          FloatS16ToDbfs(std::sqrt(energy / frame.samples_per_channel())),
          FloatS16ToDbfs(peak_level)};
}
} // namespace webrtc

View File

@ -0,0 +1,58 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
#include <memory>
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
// Class to analyze voice activity and audio levels.
class VadLevelAnalyzer {
 public:
  // Per-frame analysis outcome.
  struct Result {
    float speech_probability;  // Range: [0, 1].
    float rms_dbfs;            // Root mean square power (dBFS).
    float peak_dbfs;           // Peak power (dBFS).
  };

  // Voice Activity Detector (VAD) interface.
  class VoiceActivityDetector {
   public:
    virtual ~VoiceActivityDetector() = default;
    // Analyzes an audio frame and returns the speech probability.
    virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
  };

  // Ctor. Uses the default VAD.
  VadLevelAnalyzer();
  // Ctor. Uses the default VAD with a custom probability attack coefficient.
  explicit VadLevelAnalyzer(float vad_probability_attack);
  // Ctor. Uses a custom `vad` (must not be null).
  VadLevelAnalyzer(float vad_probability_attack,
                   std::unique_ptr<VoiceActivityDetector> vad);
  VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
  VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
  ~VadLevelAnalyzer();

  // Computes the speech probability and the level for `frame`.
  Result AnalyzeFrame(AudioFrameView<const float> frame);

 private:
  std::unique_ptr<VoiceActivityDetector> vad_;
  // Attack coefficient used to smooth rising VAD probabilities.
  const float vad_probability_attack_;
  // Last smoothed speech probability; decays instantly, attacks slowly.
  float vad_probability_ = 0.f;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_

View File

@ -0,0 +1,39 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/vector_float_frame.h"
namespace webrtc {
namespace {
// Returns one raw data pointer per channel vector in `*x`. The pointers stay
// valid only as long as the vectors in `*x` are not resized or reallocated.
std::vector<float*> ConstructChannelPointers(
    std::vector<std::vector<float>>* x) {
  std::vector<float*> channel_ptrs;
  // The final size is known up front; reserve to avoid reallocations.
  channel_ptrs.reserve(x->size());
  for (auto& v : *x) {
    channel_ptrs.push_back(v.data());
  }
  return channel_ptrs;
}
} // namespace
VectorFloatFrame::VectorFloatFrame(int num_channels,
                                   int samples_per_channel,
                                   float start_value)
    // Note: initialization order matters. `channels_` owns the samples,
    // `channel_ptrs_` points into it, and the view wraps the pointer list.
    : channels_(num_channels,
                std::vector<float>(samples_per_channel, start_value)),
      channel_ptrs_(ConstructChannelPointers(&channels_)),
      float_frame_view_(channel_ptrs_.data(),
                        channels_.size(),
                        samples_per_channel) {}

VectorFloatFrame::~VectorFloatFrame() = default;
} // namespace webrtc

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_
#define MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_
#include <vector>
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
// A construct consisting of a multi-channel audio frame, and a FloatFrame view
// of it.
class VectorFloatFrame {
 public:
  // Allocates `num_channels` channels of `samples_per_channel` samples, with
  // every sample initialized to `start_value`.
  VectorFloatFrame(int num_channels,
                   int samples_per_channel,
                   float start_value);
  // Mutable view of the frame; valid for the lifetime of this object.
  const AudioFrameView<float>& float_frame_view() { return float_frame_view_; }
  // Read-only view of the frame (lightweight non-owning copy).
  AudioFrameView<const float> float_frame_view() const {
    return float_frame_view_;
  }

  ~VectorFloatFrame();

 private:
  std::vector<std::vector<float>> channels_;  // Owns the sample data.
  std::vector<float*> channel_ptrs_;          // Per-channel raw pointers.
  AudioFrameView<float> float_frame_view_;    // View over `channel_ptrs_`.
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_