Bump to WebRTC M120 release
Some API deprecation -- ExperimentalAgc and ExperimentalNs are gone. We're continuing to carry iSAC even though it's gone upstream, but maybe we'll want to drop that soon.
This commit is contained in:
@ -17,6 +17,7 @@ rtc_library("rnn_vad") {
|
||||
"rnn.h",
|
||||
]
|
||||
|
||||
defines = []
|
||||
if (rtc_build_with_neon && current_cpu != "arm64") {
|
||||
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
|
||||
cflags = [ "-mfpu=neon" ]
|
||||
@ -24,16 +25,17 @@ rtc_library("rnn_vad") {
|
||||
|
||||
deps = [
|
||||
":rnn_vad_common",
|
||||
":rnn_vad_layers",
|
||||
":rnn_vad_lp_residual",
|
||||
":rnn_vad_pitch",
|
||||
":rnn_vad_sequence_buffer",
|
||||
":rnn_vad_spectral_features",
|
||||
"..:biquad_filter",
|
||||
"..:cpu_features",
|
||||
"../../../../api:array_view",
|
||||
"../../../../api:function_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:logging",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"//third_party/rnnoise:rnn_vad",
|
||||
]
|
||||
}
|
||||
@ -51,16 +53,13 @@ rtc_library("rnn_vad_auto_correlation") {
|
||||
]
|
||||
}
|
||||
|
||||
rtc_library("rnn_vad_common") {
|
||||
rtc_source_set("rnn_vad_common") {
|
||||
# TODO(alessiob): Make this target visibility private.
|
||||
visibility = [
|
||||
":*",
|
||||
"..:rnn_vad_with_level",
|
||||
]
|
||||
sources = [
|
||||
"common.cc",
|
||||
"common.h",
|
||||
"..:vad_wrapper",
|
||||
]
|
||||
sources = [ "common.h" ]
|
||||
deps = [
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../system_wrappers",
|
||||
@ -75,23 +74,100 @@ rtc_library("rnn_vad_lp_residual") {
|
||||
deps = [
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("rnn_vad_layers") {
|
||||
sources = [
|
||||
"rnn_fc.cc",
|
||||
"rnn_fc.h",
|
||||
"rnn_gru.cc",
|
||||
"rnn_gru.h",
|
||||
]
|
||||
|
||||
defines = []
|
||||
if (rtc_build_with_neon && current_cpu != "arm64") {
|
||||
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
|
||||
cflags = [ "-mfpu=neon" ]
|
||||
}
|
||||
|
||||
deps = [
|
||||
":rnn_vad_common",
|
||||
":vector_math",
|
||||
"..:cpu_features",
|
||||
"../../../../api:array_view",
|
||||
"../../../../api:function_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"//third_party/rnnoise:rnn_vad",
|
||||
]
|
||||
if (current_cpu == "x86" || current_cpu == "x64") {
|
||||
deps += [ ":vector_math_avx2" ]
|
||||
}
|
||||
absl_deps = [ "//third_party/abseil-cpp/absl/strings" ]
|
||||
}
|
||||
|
||||
rtc_source_set("vector_math") {
|
||||
sources = [ "vector_math.h" ]
|
||||
deps = [
|
||||
"..:cpu_features",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"../../../../rtc_base/system:arch",
|
||||
]
|
||||
}
|
||||
|
||||
if (current_cpu == "x86" || current_cpu == "x64") {
|
||||
rtc_library("vector_math_avx2") {
|
||||
sources = [ "vector_math_avx2.cc" ]
|
||||
if (is_win) {
|
||||
cflags = [ "/arch:AVX2" ]
|
||||
} else {
|
||||
cflags = [
|
||||
"-mavx2",
|
||||
"-mfma",
|
||||
]
|
||||
}
|
||||
deps = [
|
||||
":vector_math",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
rtc_library("rnn_vad_pitch") {
|
||||
sources = [
|
||||
"pitch_info.h",
|
||||
"pitch_search.cc",
|
||||
"pitch_search.h",
|
||||
"pitch_search_internal.cc",
|
||||
"pitch_search_internal.h",
|
||||
]
|
||||
|
||||
defines = []
|
||||
if (rtc_build_with_neon && current_cpu != "arm64") {
|
||||
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
|
||||
cflags = [ "-mfpu=neon" ]
|
||||
}
|
||||
|
||||
deps = [
|
||||
":rnn_vad_auto_correlation",
|
||||
":rnn_vad_common",
|
||||
":vector_math",
|
||||
"..:cpu_features",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:gtest_prod",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"../../../../rtc_base/system:arch",
|
||||
]
|
||||
if (current_cpu == "x86" || current_cpu == "x64") {
|
||||
deps += [ ":vector_math_avx2" ]
|
||||
}
|
||||
}
|
||||
|
||||
rtc_source_set("rnn_vad_ring_buffer") {
|
||||
@ -123,6 +199,7 @@ rtc_library("rnn_vad_spectral_features") {
|
||||
":rnn_vad_symmetric_matrix_buffer",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../utility:pffft_wrapper",
|
||||
]
|
||||
}
|
||||
@ -132,6 +209,7 @@ rtc_source_set("rnn_vad_symmetric_matrix_buffer") {
|
||||
deps = [
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
]
|
||||
}
|
||||
|
||||
@ -148,11 +226,11 @@ if (rtc_include_tests) {
|
||||
"../../../../api:array_view",
|
||||
"../../../../api:scoped_refptr",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../system_wrappers",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../test:fileutils",
|
||||
"../../../../test:test_support",
|
||||
]
|
||||
absl_deps = [ "//third_party/abseil-cpp/absl/strings" ]
|
||||
}
|
||||
|
||||
unittest_resources = [
|
||||
@ -181,17 +259,28 @@ if (rtc_include_tests) {
|
||||
"pitch_search_internal_unittest.cc",
|
||||
"pitch_search_unittest.cc",
|
||||
"ring_buffer_unittest.cc",
|
||||
"rnn_fc_unittest.cc",
|
||||
"rnn_gru_unittest.cc",
|
||||
"rnn_unittest.cc",
|
||||
"rnn_vad_unittest.cc",
|
||||
"sequence_buffer_unittest.cc",
|
||||
"spectral_features_internal_unittest.cc",
|
||||
"spectral_features_unittest.cc",
|
||||
"symmetric_matrix_buffer_unittest.cc",
|
||||
"vector_math_unittest.cc",
|
||||
]
|
||||
|
||||
defines = []
|
||||
if (rtc_build_with_neon && current_cpu != "arm64") {
|
||||
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
|
||||
cflags = [ "-mfpu=neon" ]
|
||||
}
|
||||
|
||||
deps = [
|
||||
":rnn_vad",
|
||||
":rnn_vad_auto_correlation",
|
||||
":rnn_vad_common",
|
||||
":rnn_vad_layers",
|
||||
":rnn_vad_lp_residual",
|
||||
":rnn_vad_pitch",
|
||||
":rnn_vad_ring_buffer",
|
||||
@ -199,20 +288,47 @@ if (rtc_include_tests) {
|
||||
":rnn_vad_spectral_features",
|
||||
":rnn_vad_symmetric_matrix_buffer",
|
||||
":test_utils",
|
||||
":vector_math",
|
||||
"..:cpu_features",
|
||||
"../..:audioproc_test_utils",
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio/",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:logging",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"../../../../rtc_base:stringutils",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../test:test_support",
|
||||
"../../utility:pffft_wrapper",
|
||||
"//third_party/rnnoise:rnn_vad",
|
||||
]
|
||||
if (current_cpu == "x86" || current_cpu == "x64") {
|
||||
deps += [ ":vector_math_avx2" ]
|
||||
}
|
||||
absl_deps = [ "//third_party/abseil-cpp/absl/memory" ]
|
||||
data = unittest_resources
|
||||
if (is_ios) {
|
||||
deps += [ ":unittests_bundle_data" ]
|
||||
}
|
||||
}
|
||||
|
||||
if (!build_with_chromium) {
|
||||
rtc_executable("rnn_vad_tool") {
|
||||
testonly = true
|
||||
sources = [ "rnn_vad_tool.cc" ]
|
||||
deps = [
|
||||
":rnn_vad",
|
||||
":rnn_vad_common",
|
||||
"..:cpu_features",
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio",
|
||||
"../../../../rtc_base:logging",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../test:test_support",
|
||||
"//third_party/abseil-cpp/absl/flags:flag",
|
||||
"//third_party/abseil-cpp/absl/flags:parse",
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ namespace {
|
||||
|
||||
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
|
||||
static_assert(1 << kAutoCorrelationFftOrder >
|
||||
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
|
||||
kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
|
||||
"");
|
||||
|
||||
} // namespace
|
||||
@ -40,20 +40,20 @@ AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
|
||||
// [ y_{m-1} ]
|
||||
// x and y are sub-array of equal length; x is never moved, whereas y slides.
|
||||
// The cross-correlation between y_0 and x corresponds to the auto-correlation
|
||||
// for the maximum pitch period. Hence, the first value in |auto_corr| has an
|
||||
// for the maximum pitch period. Hence, the first value in `auto_corr` has an
|
||||
// inverted lag equal to 0 that corresponds to a lag equal to the maximum
|
||||
// pitch period.
|
||||
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
|
||||
rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
|
||||
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
|
||||
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
|
||||
constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder;
|
||||
constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
|
||||
constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
|
||||
constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
|
||||
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
|
||||
"Mismatch between pitch buffer size, frame size and maximum "
|
||||
"pitch period.");
|
||||
static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength,
|
||||
static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
|
||||
"The FFT length is not sufficiently big to avoid cyclic "
|
||||
"convolution errors.");
|
||||
auto tmp = tmp_->GetView();
|
||||
@ -67,13 +67,12 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
|
||||
|
||||
// Compute the FFT for the sliding frames chunk. The sliding frames are
|
||||
// defined as pitch_buf[i:i+kConvolutionLength] where i in
|
||||
// [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is
|
||||
// defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength].
|
||||
// [0, kNumLags12kHz). The chunk includes all of them, hence it is
|
||||
// defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
|
||||
std::copy(pitch_buf.begin(),
|
||||
pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz,
|
||||
pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
|
||||
tmp.begin());
|
||||
std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(),
|
||||
0.f);
|
||||
std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
|
||||
fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
|
||||
|
||||
// Convolve in the frequency domain.
|
||||
@ -84,7 +83,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
|
||||
|
||||
// Extract the auto-correlation coefficients.
|
||||
std::copy(tmp.begin() + kConvolutionLength - 1,
|
||||
tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1,
|
||||
tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
|
||||
auto_corr.begin());
|
||||
}
|
||||
|
||||
|
@ -31,10 +31,10 @@ class AutoCorrelationCalculator {
|
||||
~AutoCorrelationCalculator();
|
||||
|
||||
// Computes the auto-correlation coefficients for a target pitch interval.
|
||||
// |auto_corr| indexes are inverted lags.
|
||||
// `auto_corr` indexes are inverted lags.
|
||||
void ComputeOnPitchBuffer(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
|
||||
rtc::ArrayView<float, kNumLags12kHz> auto_corr);
|
||||
|
||||
private:
|
||||
Pffft fft_;
|
||||
|
@ -1,34 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
Optimization DetectOptimization() {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
if (GetCPUInfo(kSSE2) != 0) {
|
||||
return Optimization::kSse2;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
return Optimization::kNeon;
|
||||
#endif
|
||||
|
||||
return Optimization::kNone;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
@ -18,57 +18,58 @@ namespace rnn_vad {
|
||||
|
||||
constexpr double kPi = 3.14159265358979323846;
|
||||
|
||||
constexpr size_t kSampleRate24kHz = 24000;
|
||||
constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100;
|
||||
constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
|
||||
constexpr int kSampleRate24kHz = 24000;
|
||||
constexpr int kFrameSize10ms24kHz = kSampleRate24kHz / 100;
|
||||
constexpr int kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
|
||||
|
||||
// Pitch buffer.
|
||||
constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
|
||||
constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
|
||||
constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
|
||||
constexpr int kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
|
||||
constexpr int kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
|
||||
constexpr int kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
|
||||
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
|
||||
|
||||
// 24 kHz analysis.
|
||||
// Define a higher minimum pitch period for the initial search. This is used to
|
||||
// avoid searching for very short periods, for which a refinement step is
|
||||
// responsible.
|
||||
constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
|
||||
constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
|
||||
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
|
||||
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
|
||||
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
|
||||
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
|
||||
// Number of (inverted) lags during the initial pitch search phase at 24 kHz.
|
||||
constexpr int kInitialNumLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
|
||||
// Number of (inverted) lags during the pitch search refinement phase at 24 kHz.
|
||||
constexpr int kRefineNumLags24kHz = kMaxPitch24kHz + 1;
|
||||
static_assert(
|
||||
kRefineNumLags24kHz > kInitialNumLags24kHz,
|
||||
"The refinement step must search the pitch in an extended pitch range.");
|
||||
|
||||
// 12 kHz analysis.
|
||||
constexpr size_t kSampleRate12kHz = 12000;
|
||||
constexpr size_t kFrameSize10ms12kHz = kSampleRate12kHz / 100;
|
||||
constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
|
||||
constexpr size_t kBufSize12kHz = kBufSize24kHz / 2;
|
||||
constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
|
||||
constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2;
|
||||
constexpr int kSampleRate12kHz = 12000;
|
||||
constexpr int kFrameSize10ms12kHz = kSampleRate12kHz / 100;
|
||||
constexpr int kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
|
||||
constexpr int kBufSize12kHz = kBufSize24kHz / 2;
|
||||
constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
|
||||
constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2;
|
||||
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
|
||||
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
|
||||
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
|
||||
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
|
||||
// The inverted lags for the pitch interval [`kInitialMinPitch12kHz`,
|
||||
// `kMaxPitch12kHz`] are in the range [0, `kNumLags12kHz`].
|
||||
constexpr int kNumLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
|
||||
|
||||
// 48 kHz constants.
|
||||
constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
|
||||
constexpr size_t kMaxPitch48kHz = kMaxPitch24kHz * 2;
|
||||
constexpr int kMinPitch48kHz = kMinPitch24kHz * 2;
|
||||
constexpr int kMaxPitch48kHz = kMaxPitch24kHz * 2;
|
||||
|
||||
// Spectral features.
|
||||
constexpr size_t kNumBands = 22;
|
||||
constexpr size_t kNumLowerBands = 6;
|
||||
constexpr int kNumBands = 22;
|
||||
constexpr int kNumLowerBands = 6;
|
||||
static_assert((0 < kNumLowerBands) && (kNumLowerBands < kNumBands), "");
|
||||
constexpr size_t kCepstralCoeffsHistorySize = 8;
|
||||
constexpr int kCepstralCoeffsHistorySize = 8;
|
||||
static_assert(kCepstralCoeffsHistorySize > 2,
|
||||
"The history size must at least be 3 to compute first and second "
|
||||
"derivatives.");
|
||||
|
||||
constexpr size_t kFeatureVectorSize = 42;
|
||||
|
||||
enum class Optimization { kNone, kSse2, kNeon };
|
||||
|
||||
// Detects what kind of optimizations to use for the code.
|
||||
Optimization DetectOptimization();
|
||||
constexpr int kFeatureVectorSize = 42;
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -19,23 +19,23 @@ namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Generated via "B, A = scipy.signal.butter(2, 30/12000, btype='highpass')"
|
||||
const BiQuadFilter::BiQuadCoefficients kHpfConfig24k = {
|
||||
// Computed as `scipy.signal.butter(N=2, Wn=60/24000, btype='highpass')`.
|
||||
constexpr BiQuadFilter::Config kHpfConfig24k{
|
||||
{0.99446179f, -1.98892358f, 0.99446179f},
|
||||
{-1.98889291f, 0.98895425f}};
|
||||
|
||||
} // namespace
|
||||
|
||||
FeaturesExtractor::FeaturesExtractor()
|
||||
FeaturesExtractor::FeaturesExtractor(const AvailableCpuFeatures& cpu_features)
|
||||
: use_high_pass_filter_(false),
|
||||
hpf_(kHpfConfig24k),
|
||||
pitch_buf_24kHz_(),
|
||||
pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()),
|
||||
lp_residual_(kBufSize24kHz),
|
||||
lp_residual_view_(lp_residual_.data(), kBufSize24kHz),
|
||||
pitch_estimator_(),
|
||||
pitch_estimator_(cpu_features),
|
||||
reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) {
|
||||
RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size());
|
||||
hpf_.Initialize(kHpfConfig24k);
|
||||
Reset();
|
||||
}
|
||||
|
||||
@ -44,8 +44,9 @@ FeaturesExtractor::~FeaturesExtractor() = default;
|
||||
void FeaturesExtractor::Reset() {
|
||||
pitch_buf_24kHz_.Reset();
|
||||
spectral_features_extractor_.Reset();
|
||||
if (use_high_pass_filter_)
|
||||
if (use_high_pass_filter_) {
|
||||
hpf_.Reset();
|
||||
}
|
||||
}
|
||||
|
||||
bool FeaturesExtractor::CheckSilenceComputeFeatures(
|
||||
@ -55,10 +56,10 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures(
|
||||
if (use_high_pass_filter_) {
|
||||
std::array<float, kFrameSize10ms24kHz> samples_filtered;
|
||||
hpf_.Process(samples, samples_filtered);
|
||||
// Feed buffer with the pre-processed version of |samples|.
|
||||
// Feed buffer with the pre-processed version of `samples`.
|
||||
pitch_buf_24kHz_.Push(samples_filtered);
|
||||
} else {
|
||||
// Feed buffer with |samples|.
|
||||
// Feed buffer with `samples`.
|
||||
pitch_buf_24kHz_.Push(samples);
|
||||
}
|
||||
// Extract the LP residual.
|
||||
@ -67,13 +68,12 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures(
|
||||
ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_);
|
||||
// Estimate pitch on the LP-residual and write the normalized pitch period
|
||||
// into the output vector (normalization based on training data stats).
|
||||
pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
|
||||
feature_vector[kFeatureVectorSize - 2] =
|
||||
0.01f * (static_cast<int>(pitch_info_48kHz_.period) - 300);
|
||||
pitch_period_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
|
||||
feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz_ - 300);
|
||||
// Extract lagged frames (according to the estimated pitch period).
|
||||
RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz);
|
||||
RTC_DCHECK_LE(pitch_period_48kHz_ / 2, kMaxPitch24kHz);
|
||||
auto lagged_frame = pitch_buf_24kHz_view_.subview(
|
||||
kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz);
|
||||
kMaxPitch24kHz - pitch_period_48kHz_ / 2, kFrameSize20ms24kHz);
|
||||
// Analyze reference and lagged frames checking if silence has been detected
|
||||
// and write the feature vector.
|
||||
return spectral_features_extractor_.CheckSilenceComputeFeatures(
|
||||
|
@ -16,7 +16,6 @@
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/biquad_filter.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
|
||||
@ -27,14 +26,14 @@ namespace rnn_vad {
|
||||
// Feature extractor to feed the VAD RNN.
|
||||
class FeaturesExtractor {
|
||||
public:
|
||||
FeaturesExtractor();
|
||||
explicit FeaturesExtractor(const AvailableCpuFeatures& cpu_features);
|
||||
FeaturesExtractor(const FeaturesExtractor&) = delete;
|
||||
FeaturesExtractor& operator=(const FeaturesExtractor&) = delete;
|
||||
~FeaturesExtractor();
|
||||
void Reset();
|
||||
// Analyzes the samples, computes the feature vector and returns true if
|
||||
// silence is detected (false if not). When silence is detected,
|
||||
// |feature_vector| is partially written and therefore must not be used to
|
||||
// `feature_vector` is partially written and therefore must not be used to
|
||||
// feed the VAD RNN.
|
||||
bool CheckSilenceComputeFeatures(
|
||||
rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
|
||||
@ -53,7 +52,7 @@ class FeaturesExtractor {
|
||||
PitchEstimator pitch_estimator_;
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
|
||||
SpectralFeaturesExtractor spectral_features_extractor_;
|
||||
PitchInfo pitch_info_48kHz_;
|
||||
int pitch_period_48kHz_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
|
@ -16,27 +16,23 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Computes cross-correlation coefficients between |x| and |y| and writes them
|
||||
// in |x_corr|. The lag values are in {0, ..., max_lag - 1}, where max_lag
|
||||
// equals the size of |x_corr|.
|
||||
// The |x| and |y| sub-arrays used to compute a cross-correlation coefficients
|
||||
// for a lag l have both size "size of |x| - l" - i.e., the longest sub-array is
|
||||
// used. |x| and |y| must have the same size.
|
||||
void ComputeCrossCorrelation(
|
||||
// Computes auto-correlation coefficients for `x` and writes them in
|
||||
// `auto_corr`. The lag values are in {0, ..., max_lag - 1}, where max_lag
|
||||
// equals the size of `auto_corr`.
|
||||
void ComputeAutoCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> x_corr) {
|
||||
constexpr size_t max_lag = x_corr.size();
|
||||
RTC_DCHECK_EQ(x.size(), y.size());
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
|
||||
constexpr int max_lag = auto_corr.size();
|
||||
RTC_DCHECK_LT(max_lag, x.size());
|
||||
for (size_t lag = 0; lag < max_lag; ++lag) {
|
||||
x_corr[lag] =
|
||||
std::inner_product(x.begin(), x.end() - lag, y.begin() + lag, 0.f);
|
||||
for (int lag = 0; lag < max_lag; ++lag) {
|
||||
auto_corr[lag] =
|
||||
std::inner_product(x.begin(), x.end() - lag, x.begin() + lag, 0.f);
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,9 +41,13 @@ void DenoiseAutoCorrelation(
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
|
||||
// Assume -40 dB white noise floor.
|
||||
auto_corr[0] *= 1.0001f;
|
||||
for (size_t i = 1; i < kNumLpcCoefficients; ++i) {
|
||||
auto_corr[i] -= auto_corr[i] * (0.008f * i) * (0.008f * i);
|
||||
}
|
||||
// Hard-coded values obtained as
|
||||
// [np.float32((0.008*0.008*i*i)) for i in range(1,5)].
|
||||
auto_corr[1] -= auto_corr[1] * 0.000064f;
|
||||
auto_corr[2] -= auto_corr[2] * 0.000256f;
|
||||
auto_corr[3] -= auto_corr[3] * 0.000576f;
|
||||
auto_corr[4] -= auto_corr[4] * 0.001024f;
|
||||
static_assert(kNumLpcCoefficients == 5, "Update `auto_corr`.");
|
||||
}
|
||||
|
||||
// Computes the initial inverse filter coefficients given the auto-correlation
|
||||
@ -56,9 +56,9 @@ void ComputeInitialInverseFilterCoefficients(
|
||||
rtc::ArrayView<const float, kNumLpcCoefficients> auto_corr,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients - 1> lpc_coeffs) {
|
||||
float error = auto_corr[0];
|
||||
for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
|
||||
for (int i = 0; i < kNumLpcCoefficients - 1; ++i) {
|
||||
float reflection_coeff = 0.f;
|
||||
for (size_t j = 0; j < i; ++j) {
|
||||
for (int j = 0; j < i; ++j) {
|
||||
reflection_coeff += lpc_coeffs[j] * auto_corr[i - j];
|
||||
}
|
||||
reflection_coeff += auto_corr[i + 1];
|
||||
@ -72,7 +72,7 @@ void ComputeInitialInverseFilterCoefficients(
|
||||
reflection_coeff /= -error;
|
||||
// Update LPC coefficients and total error.
|
||||
lpc_coeffs[i] = reflection_coeff;
|
||||
for (size_t j = 0; j<(i + 1)>> 1; ++j) {
|
||||
for (int j = 0; j < ((i + 1) >> 1); ++j) {
|
||||
const float tmp1 = lpc_coeffs[j];
|
||||
const float tmp2 = lpc_coeffs[i - 1 - j];
|
||||
lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2;
|
||||
@ -91,46 +91,49 @@ void ComputeAndPostProcessLpcCoefficients(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs) {
|
||||
std::array<float, kNumLpcCoefficients> auto_corr;
|
||||
ComputeCrossCorrelation(x, x, {auto_corr.data(), auto_corr.size()});
|
||||
ComputeAutoCorrelation(x, auto_corr);
|
||||
if (auto_corr[0] == 0.f) { // Empty frame.
|
||||
std::fill(lpc_coeffs.begin(), lpc_coeffs.end(), 0);
|
||||
return;
|
||||
}
|
||||
DenoiseAutoCorrelation({auto_corr.data(), auto_corr.size()});
|
||||
DenoiseAutoCorrelation(auto_corr);
|
||||
std::array<float, kNumLpcCoefficients - 1> lpc_coeffs_pre{};
|
||||
ComputeInitialInverseFilterCoefficients(auto_corr, lpc_coeffs_pre);
|
||||
// LPC coefficients post-processing.
|
||||
// TODO(bugs.webrtc.org/9076): Consider removing these steps.
|
||||
float c1 = 1.f;
|
||||
for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
|
||||
c1 *= 0.9f;
|
||||
lpc_coeffs_pre[i] *= c1;
|
||||
}
|
||||
const float c2 = 0.8f;
|
||||
lpc_coeffs[0] = lpc_coeffs_pre[0] + c2;
|
||||
lpc_coeffs[1] = lpc_coeffs_pre[1] + c2 * lpc_coeffs_pre[0];
|
||||
lpc_coeffs[2] = lpc_coeffs_pre[2] + c2 * lpc_coeffs_pre[1];
|
||||
lpc_coeffs[3] = lpc_coeffs_pre[3] + c2 * lpc_coeffs_pre[2];
|
||||
lpc_coeffs[4] = c2 * lpc_coeffs_pre[3];
|
||||
lpc_coeffs_pre[0] *= 0.9f;
|
||||
lpc_coeffs_pre[1] *= 0.9f * 0.9f;
|
||||
lpc_coeffs_pre[2] *= 0.9f * 0.9f * 0.9f;
|
||||
lpc_coeffs_pre[3] *= 0.9f * 0.9f * 0.9f * 0.9f;
|
||||
constexpr float kC = 0.8f;
|
||||
lpc_coeffs[0] = lpc_coeffs_pre[0] + kC;
|
||||
lpc_coeffs[1] = lpc_coeffs_pre[1] + kC * lpc_coeffs_pre[0];
|
||||
lpc_coeffs[2] = lpc_coeffs_pre[2] + kC * lpc_coeffs_pre[1];
|
||||
lpc_coeffs[3] = lpc_coeffs_pre[3] + kC * lpc_coeffs_pre[2];
|
||||
lpc_coeffs[4] = kC * lpc_coeffs_pre[3];
|
||||
static_assert(kNumLpcCoefficients == 5, "Update `lpc_coeffs(_pre)`.");
|
||||
}
|
||||
|
||||
void ComputeLpResidual(
|
||||
rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float> y) {
|
||||
RTC_DCHECK_LT(kNumLpcCoefficients, x.size());
|
||||
RTC_DCHECK_GT(x.size(), kNumLpcCoefficients);
|
||||
RTC_DCHECK_EQ(x.size(), y.size());
|
||||
std::array<float, kNumLpcCoefficients> input_chunk;
|
||||
input_chunk.fill(0.f);
|
||||
for (size_t i = 0; i < y.size(); ++i) {
|
||||
const float sum = std::inner_product(input_chunk.begin(), input_chunk.end(),
|
||||
lpc_coeffs.begin(), x[i]);
|
||||
// Circular shift and add a new sample.
|
||||
for (size_t j = kNumLpcCoefficients - 1; j > 0; --j)
|
||||
input_chunk[j] = input_chunk[j - 1];
|
||||
input_chunk[0] = x[i];
|
||||
// Copy result.
|
||||
y[i] = sum;
|
||||
// The code below implements the following operation:
|
||||
// y[i] = x[i] + dot_product({x[i], ..., x[i - kNumLpcCoefficients + 1]},
|
||||
// lpc_coeffs)
|
||||
// Edge case: i < kNumLpcCoefficients.
|
||||
y[0] = x[0];
|
||||
for (int i = 1; i < kNumLpcCoefficients; ++i) {
|
||||
y[i] =
|
||||
std::inner_product(x.crend() - i, x.crend(), lpc_coeffs.cbegin(), x[i]);
|
||||
}
|
||||
// Regular case.
|
||||
auto last = x.crend();
|
||||
for (int i = kNumLpcCoefficients; rtc::SafeLt(i, y.size()); ++i, --last) {
|
||||
y[i] = std::inner_product(last - kNumLpcCoefficients, last,
|
||||
lpc_coeffs.cbegin(), x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -18,17 +18,17 @@
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// LPC inverse filter length.
|
||||
constexpr size_t kNumLpcCoefficients = 5;
|
||||
// Linear predictive coding (LPC) inverse filter length.
|
||||
constexpr int kNumLpcCoefficients = 5;
|
||||
|
||||
// Given a frame |x|, computes a post-processed version of LPC coefficients
|
||||
// Given a frame `x`, computes a post-processed version of LPC coefficients
|
||||
// tailored for pitch estimation.
|
||||
void ComputeAndPostProcessLpcCoefficients(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs);
|
||||
|
||||
// Computes the LP residual for the input frame |x| and the LPC coefficients
|
||||
// |lpc_coeffs|. |y| and |x| can point to the same array for in-place
|
||||
// Computes the LP residual for the input frame `x` and the LPC coefficients
|
||||
// `lpc_coeffs`. `y` and `x` can point to the same array for in-place
|
||||
// computation.
|
||||
void ComputeLpResidual(
|
||||
rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
|
||||
|
@ -1,29 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Stores pitch period and gain information. The pitch gain measures the
|
||||
// strength of the pitch (the higher, the stronger).
|
||||
struct PitchInfo {
|
||||
PitchInfo() : period(0), gain(0.f) {}
|
||||
PitchInfo(int p, float g) : period(p), gain(g) {}
|
||||
int period;
|
||||
float gain;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
|
@ -18,38 +18,52 @@
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
PitchEstimator::PitchEstimator()
|
||||
: pitch_buf_decimated_(kBufSize12kHz),
|
||||
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
|
||||
auto_corr_(kNumInvertedLags12kHz),
|
||||
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
|
||||
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
|
||||
RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size());
|
||||
}
|
||||
PitchEstimator::PitchEstimator(const AvailableCpuFeatures& cpu_features)
|
||||
: cpu_features_(cpu_features),
|
||||
y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
|
||||
pitch_buffer_12kHz_(kBufSize12kHz),
|
||||
auto_correlation_12kHz_(kNumLags12kHz) {}
|
||||
|
||||
PitchEstimator::~PitchEstimator() = default;
|
||||
|
||||
PitchInfo PitchEstimator::Estimate(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
|
||||
int PitchEstimator::Estimate(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
|
||||
rtc::ArrayView<float, kBufSize12kHz> pitch_buffer_12kHz_view(
|
||||
pitch_buffer_12kHz_.data(), kBufSize12kHz);
|
||||
RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), pitch_buffer_12kHz_view.size());
|
||||
rtc::ArrayView<float, kNumLags12kHz> auto_correlation_12kHz_view(
|
||||
auto_correlation_12kHz_.data(), kNumLags12kHz);
|
||||
RTC_DCHECK_EQ(auto_correlation_12kHz_.size(),
|
||||
auto_correlation_12kHz_view.size());
|
||||
|
||||
// TODO(bugs.chromium.org/10480): Use `cpu_features_` to estimate pitch.
|
||||
// Perform the initial pitch search at 12 kHz.
|
||||
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
|
||||
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
|
||||
auto_corr_view_);
|
||||
std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
|
||||
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
|
||||
// Refine the pitch period estimation.
|
||||
Decimate2x(pitch_buffer, pitch_buffer_12kHz_view);
|
||||
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view,
|
||||
auto_correlation_12kHz_view);
|
||||
CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz(
|
||||
pitch_buffer_12kHz_view, auto_correlation_12kHz_view, cpu_features_);
|
||||
// The refinement is done using the pitch buffer that contains 24 kHz samples.
|
||||
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
|
||||
// Therefore, adapt the inverted lags in `pitch_candidates_inv_lags` from 12
|
||||
// to 24 kHz.
|
||||
pitch_candidates_inv_lags[0] *= 2;
|
||||
pitch_candidates_inv_lags[1] *= 2;
|
||||
size_t pitch_inv_lag_48kHz =
|
||||
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags);
|
||||
// Look for stronger harmonics to find the final pitch period and its gain.
|
||||
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
|
||||
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
|
||||
return last_pitch_48kHz_;
|
||||
pitch_periods.best *= 2;
|
||||
pitch_periods.second_best *= 2;
|
||||
|
||||
// Refine the initial pitch period estimation from 12 kHz to 48 kHz.
|
||||
// Pre-compute frame energies at 24 kHz.
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_24kHz_view(
|
||||
y_energy_24kHz_.data(), kRefineNumLags24kHz);
|
||||
RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size());
|
||||
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view,
|
||||
cpu_features_);
|
||||
// Estimation at 48 kHz.
|
||||
const int pitch_lag_48kHz = ComputePitchPeriod48kHz(
|
||||
pitch_buffer, y_energy_24kHz_view, pitch_periods, cpu_features_);
|
||||
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
|
||||
pitch_buffer, y_energy_24kHz_view,
|
||||
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz,
|
||||
last_pitch_48kHz_, cpu_features_);
|
||||
return last_pitch_48kHz_.period;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
|
@ -15,10 +15,11 @@
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
#include "rtc_base/gtest_prod_util.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
@ -26,21 +27,25 @@ namespace rnn_vad {
|
||||
// Pitch estimator.
|
||||
class PitchEstimator {
|
||||
public:
|
||||
PitchEstimator();
|
||||
explicit PitchEstimator(const AvailableCpuFeatures& cpu_features);
|
||||
PitchEstimator(const PitchEstimator&) = delete;
|
||||
PitchEstimator& operator=(const PitchEstimator&) = delete;
|
||||
~PitchEstimator();
|
||||
// Estimates the pitch period and gain. Returns the pitch estimation data for
|
||||
// 48 kHz.
|
||||
PitchInfo Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf);
|
||||
// Returns the estimated pitch period at 48 kHz.
|
||||
int Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer);
|
||||
|
||||
private:
|
||||
PitchInfo last_pitch_48kHz_;
|
||||
FRIEND_TEST_ALL_PREFIXES(RnnVadTest, PitchSearchWithinTolerance);
|
||||
float GetLastPitchStrengthForTesting() const {
|
||||
return last_pitch_48kHz_.strength;
|
||||
}
|
||||
|
||||
const AvailableCpuFeatures cpu_features_;
|
||||
PitchInfo last_pitch_48kHz_{};
|
||||
AutoCorrelationCalculator auto_corr_calculator_;
|
||||
std::vector<float> pitch_buf_decimated_;
|
||||
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
|
||||
std::vector<float> auto_corr_;
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr_view_;
|
||||
std::vector<float> y_energy_24kHz_;
|
||||
std::vector<float> pitch_buffer_12kHz_;
|
||||
std::vector<float> auto_correlation_12kHz_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
|
@ -18,103 +18,81 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Converts a lag to an inverted lag (only for 24kHz).
|
||||
size_t GetInvertedLag(size_t lag) {
|
||||
RTC_DCHECK_LE(lag, kMaxPitch24kHz);
|
||||
return kMaxPitch24kHz - lag;
|
||||
float ComputeAutoCorrelation(
|
||||
int inverted_lag,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
const VectorMath& vector_math) {
|
||||
RTC_DCHECK_LT(inverted_lag, kBufSize24kHz);
|
||||
RTC_DCHECK_LT(inverted_lag, kRefineNumLags24kHz);
|
||||
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
|
||||
return vector_math.DotProduct(
|
||||
pitch_buffer.subview(/*offset=*/kMaxPitch24kHz),
|
||||
pitch_buffer.subview(inverted_lag, kFrameSize20ms24kHz));
|
||||
}
|
||||
|
||||
float ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf,
|
||||
size_t inv_lag,
|
||||
size_t max_pitch_period) {
|
||||
RTC_DCHECK_LT(inv_lag, pitch_buf.size());
|
||||
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
|
||||
RTC_DCHECK_LE(inv_lag, max_pitch_period);
|
||||
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
|
||||
return std::inner_product(pitch_buf.begin() + max_pitch_period,
|
||||
pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f);
|
||||
}
|
||||
|
||||
// Computes a pseudo-interpolation offset for an estimated pitch period |lag| by
|
||||
// looking at the auto-correlation coefficients in the neighborhood of |lag|.
|
||||
// (namely, |prev_auto_corr|, |lag_auto_corr| and |next_auto_corr|). The output
|
||||
// is a lag in {-1, 0, +1}.
|
||||
// TODO(bugs.webrtc.org/9076): Consider removing pseudo-i since it
|
||||
// is relevant only if the spectral analysis works at a sample rate that is
|
||||
// twice as that of the pitch buffer (not so important instead for the estimated
|
||||
// pitch period feature fed into the RNN).
|
||||
int GetPitchPseudoInterpolationOffset(size_t lag,
|
||||
float prev_auto_corr,
|
||||
float lag_auto_corr,
|
||||
float next_auto_corr) {
|
||||
const float& a = prev_auto_corr;
|
||||
const float& b = lag_auto_corr;
|
||||
const float& c = next_auto_corr;
|
||||
|
||||
int offset = 0;
|
||||
if ((c - a) > 0.7f * (b - a)) {
|
||||
offset = 1; // |c| is the largest auto-correlation coefficient.
|
||||
} else if ((a - c) > 0.7f * (b - c)) {
|
||||
offset = -1; // |a| is the largest auto-correlation coefficient.
|
||||
// Given an auto-correlation coefficient `curr_auto_correlation` and its
|
||||
// neighboring values `prev_auto_correlation` and `next_auto_correlation`
|
||||
// computes a pseudo-interpolation offset to be applied to the pitch period
|
||||
// associated to `curr`. The output is a lag in {-1, 0, +1}.
|
||||
// TODO(bugs.webrtc.org/9076): Consider removing this method.
|
||||
// `GetPitchPseudoInterpolationOffset()` it is relevant only if the spectral
|
||||
// analysis works at a sample rate that is twice as that of the pitch buffer;
|
||||
// In particular, it is not relevant for the estimated pitch period feature fed
|
||||
// into the RNN.
|
||||
int GetPitchPseudoInterpolationOffset(float prev_auto_correlation,
|
||||
float curr_auto_correlation,
|
||||
float next_auto_correlation) {
|
||||
if ((next_auto_correlation - prev_auto_correlation) >
|
||||
0.7f * (curr_auto_correlation - prev_auto_correlation)) {
|
||||
return 1; // `next_auto_correlation` is the largest auto-correlation
|
||||
// coefficient.
|
||||
} else if ((prev_auto_correlation - next_auto_correlation) >
|
||||
0.7f * (curr_auto_correlation - next_auto_correlation)) {
|
||||
return -1; // `prev_auto_correlation` is the largest auto-correlation
|
||||
// coefficient.
|
||||
}
|
||||
return offset;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The
|
||||
// output sample rate is twice as that of |lag|.
|
||||
size_t PitchPseudoInterpolationLagPitchBuf(
|
||||
size_t lag,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
|
||||
// Refines a pitch period `lag` encoded as lag with pseudo-interpolation. The
|
||||
// output sample rate is twice as that of `lag`.
|
||||
int PitchPseudoInterpolationLagPitchBuf(
|
||||
int lag,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
const VectorMath& vector_math) {
|
||||
int offset = 0;
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
if (lag > 0 && lag < kMaxPitch24kHz) {
|
||||
const int inverted_lag = kMaxPitch24kHz - lag;
|
||||
offset = GetPitchPseudoInterpolationOffset(
|
||||
lag,
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1),
|
||||
kMaxPitch24kHz),
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag),
|
||||
kMaxPitch24kHz),
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1),
|
||||
kMaxPitch24kHz));
|
||||
ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer, vector_math),
|
||||
ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math),
|
||||
ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer, vector_math));
|
||||
}
|
||||
return 2 * lag + offset;
|
||||
}
|
||||
|
||||
// Refines a pitch period |inv_lag| encoded as inverted lag with
|
||||
// pseudo-interpolation. The output sample rate is twice as that of
|
||||
// |inv_lag|.
|
||||
size_t PitchPseudoInterpolationInvLagAutoCorr(
|
||||
size_t inv_lag,
|
||||
rtc::ArrayView<const float> auto_corr) {
|
||||
int offset = 0;
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
if (inv_lag > 0 && inv_lag < auto_corr.size() - 1) {
|
||||
offset = GetPitchPseudoInterpolationOffset(inv_lag, auto_corr[inv_lag + 1],
|
||||
auto_corr[inv_lag],
|
||||
auto_corr[inv_lag - 1]);
|
||||
}
|
||||
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
|
||||
// be subtracted since |inv_lag| is an inverted lag but offset is a lag.
|
||||
return 2 * inv_lag + offset;
|
||||
}
|
||||
|
||||
// Integer multipliers used in CheckLowerPitchPeriodsAndComputePitchGain() when
|
||||
// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when
|
||||
// looking for sub-harmonics.
|
||||
// The values have been chosen to serve the following algorithm. Given the
|
||||
// initial pitch period T, we examine whether one of its harmonics is the true
|
||||
// fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of
|
||||
// these harmonics, in addition to the pitch gain of itself, we choose one
|
||||
// these harmonics, in addition to the pitch strength of itself, we choose one
|
||||
// multiple of its pitch period, n*T/k, to validate it (by averaging their pitch
|
||||
// gains). The multiplier n is chosen so that n*T/k is used only one time over
|
||||
// all k. When for example k = 4, we should also expect a peak at 3*T/4. When
|
||||
// k = 8 instead we don't want to look at 2*T/8, since we have already checked
|
||||
// T/4 before. Instead, we look at T*3/8.
|
||||
// strengths). The multiplier n is chosen so that n*T/k is used only one time
|
||||
// over all k. When for example k = 4, we should also expect a peak at 3*T/4.
|
||||
// When k = 8 instead we don't want to look at 2*T/8, since we have already
|
||||
// checked T/4 before. Instead, we look at T*3/8.
|
||||
// The array can be generate in Python as follows:
|
||||
// from fractions import Fraction
|
||||
// # Smallest positive integer not in X.
|
||||
@ -131,96 +109,220 @@ size_t PitchPseudoInterpolationInvLagAutoCorr(
|
||||
constexpr std::array<int, 14> kSubHarmonicMultipliers = {
|
||||
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
|
||||
|
||||
// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
|
||||
// a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)].
|
||||
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
|
||||
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
|
||||
struct Range {
|
||||
int min;
|
||||
int max;
|
||||
};
|
||||
|
||||
// Number of analyzed pitches to the left(right) of a pitch candidate.
|
||||
constexpr int kPitchNeighborhoodRadius = 2;
|
||||
|
||||
// Creates a pitch period interval centered in `inverted_lag` with hard-coded
|
||||
// radius. Clipping is applied so that the interval is always valid for a 24 kHz
|
||||
// pitch buffer.
|
||||
Range CreateInvertedLagRange(int inverted_lag) {
|
||||
return {std::max(inverted_lag - kPitchNeighborhoodRadius, 0),
|
||||
std::min(inverted_lag + kPitchNeighborhoodRadius,
|
||||
kInitialNumLags24kHz - 1)};
|
||||
}
|
||||
|
||||
constexpr int kNumPitchCandidates = 2; // Best and second best.
|
||||
// Maximum number of analyzed pitch periods.
|
||||
constexpr int kMaxPitchPeriods24kHz =
|
||||
kNumPitchCandidates * (2 * kPitchNeighborhoodRadius + 1);
|
||||
|
||||
// Collection of inverted lags.
|
||||
class InvertedLagsIndex {
|
||||
public:
|
||||
InvertedLagsIndex() : num_entries_(0) {}
|
||||
// Adds an inverted lag to the index. Cannot add more than
|
||||
// `kMaxPitchPeriods24kHz` values.
|
||||
void Append(int inverted_lag) {
|
||||
RTC_DCHECK_LT(num_entries_, kMaxPitchPeriods24kHz);
|
||||
inverted_lags_[num_entries_++] = inverted_lag;
|
||||
}
|
||||
const int* data() const { return inverted_lags_.data(); }
|
||||
int size() const { return num_entries_; }
|
||||
|
||||
private:
|
||||
std::array<int, kMaxPitchPeriods24kHz> inverted_lags_;
|
||||
int num_entries_;
|
||||
};
|
||||
|
||||
// Computes the auto correlation coefficients for the inverted lags in the
|
||||
// closed interval `inverted_lags`. Updates `inverted_lags_index` by appending
|
||||
// the inverted lags for the computed auto correlation values.
|
||||
void ComputeAutoCorrelation(
|
||||
Range inverted_lags,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation,
|
||||
InvertedLagsIndex& inverted_lags_index,
|
||||
const VectorMath& vector_math) {
|
||||
// Check valid range.
|
||||
RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max);
|
||||
// Trick to avoid zero initialization of `auto_correlation`.
|
||||
// Needed by the pseudo-interpolation.
|
||||
if (inverted_lags.min > 0) {
|
||||
auto_correlation[inverted_lags.min - 1] = 0.f;
|
||||
}
|
||||
if (inverted_lags.max < kInitialNumLags24kHz - 1) {
|
||||
auto_correlation[inverted_lags.max + 1] = 0.f;
|
||||
}
|
||||
// Check valid `inverted_lag` indexes.
|
||||
RTC_DCHECK_GE(inverted_lags.min, 0);
|
||||
RTC_DCHECK_LT(inverted_lags.max, kInitialNumLags24kHz);
|
||||
for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max;
|
||||
++inverted_lag) {
|
||||
auto_correlation[inverted_lag] =
|
||||
ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math);
|
||||
inverted_lags_index.Append(inverted_lag);
|
||||
}
|
||||
}
|
||||
|
||||
// Searches the strongest pitch period at 24 kHz and returns its inverted lag at
|
||||
// 48 kHz.
|
||||
int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const int> inverted_lags,
|
||||
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
const VectorMath& vector_math) {
|
||||
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
|
||||
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
|
||||
int best_inverted_lag = 0; // Pitch period.
|
||||
float best_numerator = -1.f; // Pitch strength numerator.
|
||||
float best_denominator = 0.f; // Pitch strength denominator.
|
||||
for (int inverted_lag : inverted_lags) {
|
||||
// A pitch candidate must have positive correlation.
|
||||
if (auto_correlation[inverted_lag] > 0.f) {
|
||||
// Auto-correlation energy normalized by frame energy.
|
||||
const float numerator =
|
||||
auto_correlation[inverted_lag] * auto_correlation[inverted_lag];
|
||||
const float denominator = y_energy[inverted_lag];
|
||||
// Compare numerator/denominator ratios without using divisions.
|
||||
if (numerator * best_denominator > best_numerator * denominator) {
|
||||
best_inverted_lag = inverted_lag;
|
||||
best_numerator = numerator;
|
||||
best_denominator = denominator;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Pseudo-interpolation to transform `best_inverted_lag` (24 kHz pitch) to a
|
||||
// 48 kHz pitch period.
|
||||
if (best_inverted_lag == 0 || best_inverted_lag >= kInitialNumLags24kHz - 1) {
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
return best_inverted_lag * 2;
|
||||
}
|
||||
int offset = GetPitchPseudoInterpolationOffset(
|
||||
auto_correlation[best_inverted_lag + 1],
|
||||
auto_correlation[best_inverted_lag],
|
||||
auto_correlation[best_inverted_lag - 1]);
|
||||
// TODO(bugs.webrtc.org/9076): When retraining, check if `offset` below should
|
||||
// be subtracted since `inverted_lag` is an inverted lag but offset is a lag.
|
||||
return 2 * best_inverted_lag + offset;
|
||||
}
|
||||
|
||||
// Returns an alternative pitch period for `pitch_period` given a `multiplier`
|
||||
// and a `divisor` of the period.
|
||||
constexpr int GetAlternativePitchPeriod(int pitch_period,
|
||||
int multiplier,
|
||||
int divisor) {
|
||||
RTC_DCHECK_GT(divisor, 0);
|
||||
// Same as `round(multiplier * pitch_period / divisor)`.
|
||||
return (2 * multiplier * pitch_period + divisor) / (2 * divisor);
|
||||
}
|
||||
|
||||
// Returns true if the alternative pitch period is stronger than the initial one
|
||||
// given the last estimated pitch and the value of `period_divisor` used to
|
||||
// compute the alternative pitch period via `GetAlternativePitchPeriod()`.
|
||||
bool IsAlternativePitchStrongerThanInitial(PitchInfo last,
|
||||
PitchInfo initial,
|
||||
PitchInfo alternative,
|
||||
int period_divisor) {
|
||||
// Initial pitch period candidate thresholds for a sample rate of 24 kHz.
|
||||
// Computed as [5*k*k for k in range(16)].
|
||||
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
|
||||
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
|
||||
static_assert(
|
||||
kInitialPitchPeriodThresholds.size() == kSubHarmonicMultipliers.size(),
|
||||
"");
|
||||
RTC_DCHECK_GE(last.period, 0);
|
||||
RTC_DCHECK_GE(initial.period, 0);
|
||||
RTC_DCHECK_GE(alternative.period, 0);
|
||||
RTC_DCHECK_GE(period_divisor, 2);
|
||||
// Compute a term that lowers the threshold when `alternative.period` is close
|
||||
// to the last estimated period `last.period` - i.e., pitch tracking.
|
||||
float lower_threshold_term = 0.f;
|
||||
if (std::abs(alternative.period - last.period) <= 1) {
|
||||
// The candidate pitch period is within 1 sample from the last one.
|
||||
// Make the candidate at `alternative.period` very easy to be accepted.
|
||||
lower_threshold_term = last.strength;
|
||||
} else if (std::abs(alternative.period - last.period) == 2 &&
|
||||
initial.period >
|
||||
kInitialPitchPeriodThresholds[period_divisor - 2]) {
|
||||
// The candidate pitch period is 2 samples far from the last one and the
|
||||
// period `initial.period` (from which `alternative.period` has been
|
||||
// derived) is greater than a threshold. Make `alternative.period` easy to
|
||||
// be accepted.
|
||||
lower_threshold_term = 0.5f * last.strength;
|
||||
}
|
||||
// Set the threshold based on the strength of the initial estimate
|
||||
// `initial.period`. Also reduce the chance of false positives caused by a
|
||||
// bias towards high frequencies (originating from short-term correlations).
|
||||
float threshold =
|
||||
std::max(0.3f, 0.7f * initial.strength - lower_threshold_term);
|
||||
if (alternative.period < 3 * kMinPitch24kHz) {
|
||||
// High frequency.
|
||||
threshold = std::max(0.4f, 0.85f * initial.strength - lower_threshold_term);
|
||||
} else if (alternative.period < 2 * kMinPitch24kHz) {
|
||||
// Even higher frequency.
|
||||
threshold = std::max(0.5f, 0.9f * initial.strength - lower_threshold_term);
|
||||
}
|
||||
return alternative.strength > threshold;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
||||
rtc::ArrayView<float, kBufSize12kHz> dst) {
|
||||
// TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter.
|
||||
static_assert(2 * dst.size() == src.size(), "");
|
||||
for (size_t i = 0; i < dst.size(); ++i) {
|
||||
static_assert(2 * kBufSize12kHz == kBufSize24kHz, "");
|
||||
for (int i = 0; i < kBufSize12kHz; ++i) {
|
||||
dst[i] = src[2 * i];
|
||||
}
|
||||
}
|
||||
|
||||
float ComputePitchGainThreshold(int candidate_pitch_period,
|
||||
int pitch_period_ratio,
|
||||
int initial_pitch_period,
|
||||
float initial_pitch_gain,
|
||||
int prev_pitch_period,
|
||||
float prev_pitch_gain) {
|
||||
// Map arguments to more compact aliases.
|
||||
const int& t1 = candidate_pitch_period;
|
||||
const int& k = pitch_period_ratio;
|
||||
const int& t0 = initial_pitch_period;
|
||||
const float& g0 = initial_pitch_gain;
|
||||
const int& t_prev = prev_pitch_period;
|
||||
const float& g_prev = prev_pitch_gain;
|
||||
|
||||
// Validate input.
|
||||
RTC_DCHECK_GE(t1, 0);
|
||||
RTC_DCHECK_GE(k, 2);
|
||||
RTC_DCHECK_GE(t0, 0);
|
||||
RTC_DCHECK_GE(t_prev, 0);
|
||||
|
||||
// Compute a term that lowers the threshold when |t1| is close to the last
|
||||
// estimated period |t_prev| - i.e., pitch tracking.
|
||||
float lower_threshold_term = 0;
|
||||
if (abs(t1 - t_prev) <= 1) {
|
||||
// The candidate pitch period is within 1 sample from the previous one.
|
||||
// Make the candidate at |t1| very easy to be accepted.
|
||||
lower_threshold_term = g_prev;
|
||||
} else if (abs(t1 - t_prev) == 2 &&
|
||||
t0 > kInitialPitchPeriodThresholds[k - 2]) {
|
||||
// The candidate pitch period is 2 samples far from the previous one and the
|
||||
// period |t0| (from which |t1| has been derived) is greater than a
|
||||
// threshold. Make |t1| easy to be accepted.
|
||||
lower_threshold_term = 0.5f * g_prev;
|
||||
}
|
||||
// Set the threshold based on the gain of the initial estimate |t0|. Also
|
||||
// reduce the chance of false positives caused by a bias towards high
|
||||
// frequencies (originating from short-term correlations).
|
||||
float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
|
||||
if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
|
||||
// High frequency.
|
||||
threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
|
||||
} else if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
|
||||
// Even higher frequency.
|
||||
threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
|
||||
}
|
||||
return threshold;
|
||||
}
|
||||
|
||||
void ComputeSlidingFrameSquareEnergies(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values) {
|
||||
float yy =
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz);
|
||||
yy_values[0] = yy;
|
||||
for (size_t i = 1; i < yy_values.size(); ++i) {
|
||||
RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz);
|
||||
RTC_DCHECK_LE(i, kMaxPitch24kHz);
|
||||
const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i];
|
||||
const float new_coeff = pitch_buf[kMaxPitch24kHz - i];
|
||||
yy -= old_coeff * old_coeff;
|
||||
yy += new_coeff * new_coeff;
|
||||
yy = std::max(0.f, yy);
|
||||
yy_values[i] = yy;
|
||||
void ComputeSlidingFrameSquareEnergies24kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy,
|
||||
AvailableCpuFeatures cpu_features) {
|
||||
VectorMath vector_math(cpu_features);
|
||||
static_assert(kFrameSize20ms24kHz < kBufSize24kHz, "");
|
||||
const auto frame_20ms_view = pitch_buffer.subview(0, kFrameSize20ms24kHz);
|
||||
float yy = vector_math.DotProduct(frame_20ms_view, frame_20ms_view);
|
||||
y_energy[0] = yy;
|
||||
static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, "");
|
||||
static_assert(kMaxPitch24kHz < kRefineNumLags24kHz, "");
|
||||
for (int inverted_lag = 0; inverted_lag < kMaxPitch24kHz; ++inverted_lag) {
|
||||
yy -= pitch_buffer[inverted_lag] * pitch_buffer[inverted_lag];
|
||||
yy += pitch_buffer[inverted_lag + kFrameSize20ms24kHz] *
|
||||
pitch_buffer[inverted_lag + kFrameSize20ms24kHz];
|
||||
yy = std::max(1.f, yy);
|
||||
y_energy[inverted_lag + 1] = yy;
|
||||
}
|
||||
}
|
||||
|
||||
std::array<size_t, 2> FindBestPitchPeriods(
|
||||
rtc::ArrayView<const float> auto_corr,
|
||||
rtc::ArrayView<const float> pitch_buf,
|
||||
size_t max_pitch_period) {
|
||||
CandidatePitchPeriods ComputePitchPeriod12kHz(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation,
|
||||
AvailableCpuFeatures cpu_features) {
|
||||
static_assert(kMaxPitch12kHz > kNumLags12kHz, "");
|
||||
static_assert(kMaxPitch12kHz < kBufSize12kHz, "");
|
||||
|
||||
// Stores a pitch candidate period and strength information.
|
||||
struct PitchCandidate {
|
||||
// Pitch period encoded as inverted lag.
|
||||
size_t period_inverted_lag = 0;
|
||||
int period_inverted_lag = 0;
|
||||
// Pitch strength encoded as a ratio.
|
||||
float strength_numerator = -1.f;
|
||||
float strength_denominator = 0.f;
|
||||
@ -232,25 +334,22 @@ std::array<size_t, 2> FindBestPitchPeriods(
|
||||
}
|
||||
};
|
||||
|
||||
RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
|
||||
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
|
||||
const size_t frame_size = pitch_buf.size() - max_pitch_period;
|
||||
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
|
||||
float yy =
|
||||
std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1,
|
||||
pitch_buf.begin(), 1.f);
|
||||
VectorMath vector_math(cpu_features);
|
||||
static_assert(kFrameSize20ms12kHz + 1 < kBufSize12kHz, "");
|
||||
const auto frame_view = pitch_buffer.subview(0, kFrameSize20ms12kHz + 1);
|
||||
float denominator = 1.f + vector_math.DotProduct(frame_view, frame_view);
|
||||
// Search best and second best pitches by looking at the scaled
|
||||
// auto-correlation.
|
||||
PitchCandidate candidate;
|
||||
PitchCandidate best;
|
||||
PitchCandidate second_best;
|
||||
second_best.period_inverted_lag = 1;
|
||||
for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
|
||||
for (int inverted_lag = 0; inverted_lag < kNumLags12kHz; ++inverted_lag) {
|
||||
// A pitch candidate must have positive correlation.
|
||||
if (auto_corr[inv_lag] > 0) {
|
||||
candidate.period_inverted_lag = inv_lag;
|
||||
candidate.strength_numerator = auto_corr[inv_lag] * auto_corr[inv_lag];
|
||||
candidate.strength_denominator = yy;
|
||||
if (auto_correlation[inverted_lag] > 0.f) {
|
||||
PitchCandidate candidate{
|
||||
inverted_lag,
|
||||
auto_correlation[inverted_lag] * auto_correlation[inverted_lag],
|
||||
denominator};
|
||||
if (candidate.HasStrongerPitchThan(second_best)) {
|
||||
if (candidate.HasStrongerPitchThan(best)) {
|
||||
second_best = best;
|
||||
@ -260,143 +359,154 @@ std::array<size_t, 2> FindBestPitchPeriods(
|
||||
}
|
||||
}
|
||||
}
|
||||
// Update |squared_energy_y| for the next inverted lag.
|
||||
const float old_coeff = pitch_buf[inv_lag];
|
||||
const float new_coeff = pitch_buf[inv_lag + frame_size];
|
||||
yy -= old_coeff * old_coeff;
|
||||
yy += new_coeff * new_coeff;
|
||||
yy = std::max(0.f, yy);
|
||||
// Update `squared_energy_y` for the next inverted lag.
|
||||
const float y_old = pitch_buffer[inverted_lag];
|
||||
const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms12kHz];
|
||||
denominator -= y_old * y_old;
|
||||
denominator += y_new * y_new;
|
||||
denominator = std::max(0.f, denominator);
|
||||
}
|
||||
return {{best.period_inverted_lag, second_best.period_inverted_lag}};
|
||||
return {best.period_inverted_lag, second_best.period_inverted_lag};
|
||||
}
|
||||
|
||||
size_t RefinePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
rtc::ArrayView<const size_t, 2> inv_lags) {
|
||||
// Compute the auto-correlation terms only for neighbors of the given pitch
|
||||
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
|
||||
// for a few lag values).
|
||||
std::array<float, kNumInvertedLags24kHz> auto_corr;
|
||||
auto_corr.fill(0.f); // Zeros become ignored lags in FindBestPitchPeriods().
|
||||
auto is_neighbor = [](size_t i, size_t j) {
|
||||
return ((i > j) ? (i - j) : (j - i)) <= 2;
|
||||
};
|
||||
for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
|
||||
if (is_neighbor(inv_lag, inv_lags[0]) || is_neighbor(inv_lag, inv_lags[1]))
|
||||
auto_corr[inv_lag] =
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, kMaxPitch24kHz);
|
||||
int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
CandidatePitchPeriods pitch_candidates,
|
||||
AvailableCpuFeatures cpu_features) {
|
||||
// Compute the auto-correlation terms only for neighbors of the two pitch
|
||||
// candidates (best and second best).
|
||||
std::array<float, kInitialNumLags24kHz> auto_correlation;
|
||||
InvertedLagsIndex inverted_lags_index;
|
||||
// Create two inverted lag ranges so that `r1` precedes `r2`.
|
||||
const bool swap_candidates =
|
||||
pitch_candidates.best > pitch_candidates.second_best;
|
||||
const Range r1 = CreateInvertedLagRange(
|
||||
swap_candidates ? pitch_candidates.second_best : pitch_candidates.best);
|
||||
const Range r2 = CreateInvertedLagRange(
|
||||
swap_candidates ? pitch_candidates.best : pitch_candidates.second_best);
|
||||
// Check valid ranges.
|
||||
RTC_DCHECK_LE(r1.min, r1.max);
|
||||
RTC_DCHECK_LE(r2.min, r2.max);
|
||||
// Check `r1` precedes `r2`.
|
||||
RTC_DCHECK_LE(r1.min, r2.min);
|
||||
RTC_DCHECK_LE(r1.max, r2.max);
|
||||
VectorMath vector_math(cpu_features);
|
||||
if (r1.max + 1 >= r2.min) {
|
||||
// Overlapping or adjacent ranges.
|
||||
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation,
|
||||
inverted_lags_index, vector_math);
|
||||
} else {
|
||||
// Disjoint ranges.
|
||||
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation,
|
||||
inverted_lags_index, vector_math);
|
||||
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation,
|
||||
inverted_lags_index, vector_math);
|
||||
}
|
||||
// Find best pitch at 24 kHz.
|
||||
const auto pitch_candidates_inv_lags = FindBestPitchPeriods(
|
||||
{auto_corr.data(), auto_corr.size()},
|
||||
{pitch_buf.data(), pitch_buf.size()}, kMaxPitch24kHz);
|
||||
const auto inv_lag = pitch_candidates_inv_lags[0]; // Refine the best.
|
||||
// Pseudo-interpolation.
|
||||
return PitchPseudoInterpolationInvLagAutoCorr(inv_lag, auto_corr);
|
||||
return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index,
|
||||
auto_correlation, y_energy, vector_math);
|
||||
}
|
||||
|
||||
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
int initial_pitch_period_48kHz,
|
||||
PitchInfo prev_pitch_48kHz) {
|
||||
PitchInfo last_pitch_48kHz,
|
||||
AvailableCpuFeatures cpu_features) {
|
||||
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
|
||||
RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
|
||||
|
||||
// Stores information for a refined pitch candidate.
|
||||
struct RefinedPitchCandidate {
|
||||
RefinedPitchCandidate() {}
|
||||
RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy)
|
||||
: period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
|
||||
int period_24kHz;
|
||||
// Pitch strength information.
|
||||
float gain;
|
||||
// Additional pitch strength information used for the final estimation of
|
||||
// pitch gain.
|
||||
float xy; // Cross-correlation.
|
||||
float yy; // Auto-correlation.
|
||||
int period;
|
||||
float strength;
|
||||
// Additional strength data used for the final pitch estimation.
|
||||
float xy; // Auto-correlation.
|
||||
float y_energy; // Energy of the sliding frame `y`.
|
||||
};
|
||||
|
||||
// Initialize.
|
||||
std::array<float, kMaxPitch24kHz + 1> yy_values;
|
||||
ComputeSlidingFrameSquareEnergies(pitch_buf,
|
||||
{yy_values.data(), yy_values.size()});
|
||||
const float xx = yy_values[0];
|
||||
// Helper lambdas.
|
||||
const auto pitch_gain = [](float xy, float yy, float xx) {
|
||||
RTC_DCHECK_LE(0.f, xx * yy);
|
||||
return xy / std::sqrt(1.f + xx * yy);
|
||||
const float x_energy = y_energy[kMaxPitch24kHz];
|
||||
const auto pitch_strength = [x_energy](float xy, float y_energy) {
|
||||
RTC_DCHECK_GE(x_energy * y_energy, 0.f);
|
||||
return xy / std::sqrt(1.f + x_energy * y_energy);
|
||||
};
|
||||
// Initial pitch candidate gain.
|
||||
VectorMath vector_math(cpu_features);
|
||||
|
||||
// Initialize the best pitch candidate with `initial_pitch_period_48kHz`.
|
||||
RefinedPitchCandidate best_pitch;
|
||||
best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2,
|
||||
static_cast<int>(kMaxPitch24kHz - 1));
|
||||
best_pitch.xy = ComputeAutoCorrelationCoeff(
|
||||
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
|
||||
best_pitch.yy = yy_values[best_pitch.period_24kHz];
|
||||
best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx);
|
||||
best_pitch.period =
|
||||
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
|
||||
best_pitch.xy = ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period,
|
||||
pitch_buffer, vector_math);
|
||||
best_pitch.y_energy = y_energy[kMaxPitch24kHz - best_pitch.period];
|
||||
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy);
|
||||
// Keep a copy of the initial pitch candidate.
|
||||
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
|
||||
// 24 kHz version of the last estimated pitch.
|
||||
const PitchInfo last_pitch{last_pitch_48kHz.period / 2,
|
||||
last_pitch_48kHz.strength};
|
||||
|
||||
// Store the initial pitch period information.
|
||||
const size_t initial_pitch_period = best_pitch.period_24kHz;
|
||||
const float initial_pitch_gain = best_pitch.gain;
|
||||
|
||||
// Given the initial pitch estimation, check lower periods (i.e., harmonics).
|
||||
const auto alternative_period = [](int period, int k, int n) -> int {
|
||||
RTC_DCHECK_GT(k, 0);
|
||||
return (2 * n * period + k) / (2 * k); // Same as round(n*period/k).
|
||||
};
|
||||
for (int k = 2; k < static_cast<int>(kSubHarmonicMultipliers.size() + 2);
|
||||
++k) {
|
||||
int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
|
||||
if (static_cast<size_t>(candidate_pitch_period) < kMinPitch24kHz) {
|
||||
break;
|
||||
// Find `max_period_divisor` such that the result of
|
||||
// `GetAlternativePitchPeriod(initial_pitch_period, 1, max_period_divisor)`
|
||||
// equals `kMinPitch24kHz`.
|
||||
const int max_period_divisor =
|
||||
(2 * initial_pitch.period) / (2 * kMinPitch24kHz - 1);
|
||||
for (int period_divisor = 2; period_divisor <= max_period_divisor;
|
||||
++period_divisor) {
|
||||
PitchInfo alternative_pitch;
|
||||
alternative_pitch.period = GetAlternativePitchPeriod(
|
||||
initial_pitch.period, /*multiplier=*/1, period_divisor);
|
||||
RTC_DCHECK_GE(alternative_pitch.period, kMinPitch24kHz);
|
||||
// When looking at `alternative_pitch.period`, we also look at one of its
|
||||
// sub-harmonics. `kSubHarmonicMultipliers` is used to know where to look.
|
||||
// `period_divisor` == 2 is a special case since `dual_alternative_period`
|
||||
// might be greater than the maximum pitch period.
|
||||
int dual_alternative_period = GetAlternativePitchPeriod(
|
||||
initial_pitch.period, kSubHarmonicMultipliers[period_divisor - 2],
|
||||
period_divisor);
|
||||
RTC_DCHECK_GT(dual_alternative_period, 0);
|
||||
if (period_divisor == 2 && dual_alternative_period > kMaxPitch24kHz) {
|
||||
dual_alternative_period = initial_pitch.period;
|
||||
}
|
||||
// When looking at |candidate_pitch_period|, we also look at one of its
|
||||
// sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
|
||||
// |k| == 2 is a special case since |candidate_pitch_secondary_period| might
|
||||
// be greater than the maximum pitch period.
|
||||
int candidate_pitch_secondary_period = alternative_period(
|
||||
initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
|
||||
RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
|
||||
if (k == 2 &&
|
||||
candidate_pitch_secondary_period > static_cast<int>(kMaxPitch24kHz)) {
|
||||
candidate_pitch_secondary_period = initial_pitch_period;
|
||||
}
|
||||
RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
|
||||
RTC_DCHECK_NE(alternative_pitch.period, dual_alternative_period)
|
||||
<< "The lower pitch period and the additional sub-harmonic must not "
|
||||
"coincide.";
|
||||
// Compute an auto-correlation score for the primary pitch candidate
|
||||
// |candidate_pitch_period| by also looking at its possible sub-harmonic
|
||||
// |candidate_pitch_secondary_period|.
|
||||
float xy_primary_period = ComputeAutoCorrelationCoeff(
|
||||
pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz);
|
||||
float xy_secondary_period = ComputeAutoCorrelationCoeff(
|
||||
pitch_buf, GetInvertedLag(candidate_pitch_secondary_period),
|
||||
kMaxPitch24kHz);
|
||||
float xy = 0.5f * (xy_primary_period + xy_secondary_period);
|
||||
float yy = 0.5f * (yy_values[candidate_pitch_period] +
|
||||
yy_values[candidate_pitch_secondary_period]);
|
||||
float candidate_pitch_gain = pitch_gain(xy, yy, xx);
|
||||
// `alternative_pitch.period` by also looking at its possible sub-harmonic
|
||||
// `dual_alternative_period`.
|
||||
const float xy_primary_period = ComputeAutoCorrelation(
|
||||
kMaxPitch24kHz - alternative_pitch.period, pitch_buffer, vector_math);
|
||||
// TODO(webrtc:10480): Copy `xy_primary_period` if the secondary period is
|
||||
// equal to the primary one.
|
||||
const float xy_secondary_period = ComputeAutoCorrelation(
|
||||
kMaxPitch24kHz - dual_alternative_period, pitch_buffer, vector_math);
|
||||
const float xy = 0.5f * (xy_primary_period + xy_secondary_period);
|
||||
const float yy =
|
||||
0.5f * (y_energy[kMaxPitch24kHz - alternative_pitch.period] +
|
||||
y_energy[kMaxPitch24kHz - dual_alternative_period]);
|
||||
alternative_pitch.strength = pitch_strength(xy, yy);
|
||||
|
||||
// Maybe update best period.
|
||||
float threshold = ComputePitchGainThreshold(
|
||||
candidate_pitch_period, k, initial_pitch_period, initial_pitch_gain,
|
||||
prev_pitch_48kHz.period / 2, prev_pitch_48kHz.gain);
|
||||
if (candidate_pitch_gain > threshold) {
|
||||
best_pitch = {candidate_pitch_period, candidate_pitch_gain, xy, yy};
|
||||
if (IsAlternativePitchStrongerThanInitial(
|
||||
last_pitch, initial_pitch, alternative_pitch, period_divisor)) {
|
||||
best_pitch = {alternative_pitch.period, alternative_pitch.strength, xy,
|
||||
yy};
|
||||
}
|
||||
}
|
||||
|
||||
// Final pitch gain and period.
|
||||
// Final pitch strength and period.
|
||||
best_pitch.xy = std::max(0.f, best_pitch.xy);
|
||||
RTC_DCHECK_LE(0.f, best_pitch.yy);
|
||||
float final_pitch_gain = (best_pitch.yy <= best_pitch.xy)
|
||||
? 1.f
|
||||
: best_pitch.xy / (best_pitch.yy + 1.f);
|
||||
final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
|
||||
RTC_DCHECK_LE(0.f, best_pitch.y_energy);
|
||||
float final_pitch_strength =
|
||||
(best_pitch.y_energy <= best_pitch.xy)
|
||||
? 1.f
|
||||
: best_pitch.xy / (best_pitch.y_energy + 1.f);
|
||||
final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength);
|
||||
int final_pitch_period_48kHz = std::max(
|
||||
kMinPitch48kHz,
|
||||
PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
|
||||
kMinPitch48kHz, PitchPseudoInterpolationLagPitchBuf(
|
||||
best_pitch.period, pitch_buffer, vector_math));
|
||||
|
||||
return {final_pitch_period_48kHz, final_pitch_gain};
|
||||
return {final_pitch_period_48kHz, final_pitch_strength};
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
|
@ -14,10 +14,11 @@
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
#include <utility>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
@ -26,50 +27,86 @@ namespace rnn_vad {
|
||||
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
||||
rtc::ArrayView<float, kBufSize12kHz> dst);
|
||||
|
||||
// Computes a gain threshold for a candidate pitch period given the initial and
|
||||
// the previous pitch period and gain estimates and the pitch period ratio used
|
||||
// to derive the candidate pitch period from the initial period.
|
||||
float ComputePitchGainThreshold(int candidate_pitch_period,
|
||||
int pitch_period_ratio,
|
||||
int initial_pitch_period,
|
||||
float initial_pitch_gain,
|
||||
int prev_pitch_period,
|
||||
float prev_pitch_gain);
|
||||
|
||||
// Computes the sum of squared samples for every sliding frame in the pitch
|
||||
// buffer. |yy_values| indexes are lags.
|
||||
// Key concepts and keywords used below in this file.
|
||||
//
|
||||
// The pitch buffer is structured as depicted below:
|
||||
// |.........|...........|
|
||||
// a b
|
||||
// The part on the left, named "a" contains the oldest samples, whereas "b" the
|
||||
// most recent ones. The size of "a" corresponds to the maximum pitch period,
|
||||
// that of "b" to the frame size (e.g., 16 ms and 20 ms respectively).
|
||||
void ComputeSlidingFrameSquareEnergies(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values);
|
||||
// The pitch estimation relies on a pitch buffer, which is an array-like data
|
||||
// structured designed as follows:
|
||||
//
|
||||
// |....A....|.....B.....|
|
||||
//
|
||||
// The part on the left, named `A` contains the oldest samples, whereas `B`
|
||||
// contains the most recent ones. The size of `A` corresponds to the maximum
|
||||
// pitch period, that of `B` to the analysis frame size (e.g., 16 ms and 20 ms
|
||||
// respectively).
|
||||
//
|
||||
// Pitch estimation is essentially based on the analysis of two 20 ms frames
|
||||
// extracted from the pitch buffer. One frame, called `x`, is kept fixed and
|
||||
// corresponds to `B` - i.e., the most recent 20 ms. The other frame, called
|
||||
// `y`, is extracted from different parts of the buffer instead.
|
||||
//
|
||||
// The offset between `x` and `y` corresponds to a specific pitch period.
|
||||
// For instance, if `y` is positioned at the beginning of the pitch buffer, then
|
||||
// the cross-correlation between `x` and `y` can be used as an indication of the
|
||||
// strength for the maximum pitch.
|
||||
//
|
||||
// Such an offset can be encoded in two ways:
|
||||
// - As a lag, which is the index in the pitch buffer for the first item in `y`
|
||||
// - As an inverted lag, which is the number of samples from the beginning of
|
||||
// `x` and the end of `y`
|
||||
//
|
||||
// |---->| lag
|
||||
// |....A....|.....B.....|
|
||||
// |<--| inverted lag
|
||||
// |.....y.....| `y` 20 ms frame
|
||||
//
|
||||
// The inverted lag has the advantage of being directly proportional to the
|
||||
// corresponding pitch period.
|
||||
|
||||
// Given the auto-correlation coefficients stored according to
|
||||
// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best
|
||||
// and the second best pitch periods.
|
||||
std::array<size_t, 2> FindBestPitchPeriods(
|
||||
rtc::ArrayView<const float> auto_corr,
|
||||
rtc::ArrayView<const float> pitch_buf,
|
||||
size_t max_pitch_period);
|
||||
// Computes the sum of squared samples for every sliding frame `y` in the pitch
|
||||
// buffer. The indexes of `y_energy` are inverted lags.
|
||||
void ComputeSlidingFrameSquareEnergies24kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
// Refines the pitch period estimation given the pitch buffer |pitch_buf| and
|
||||
// the initial pitch period estimation |inv_lags|. Returns an inverted lag at
|
||||
// 48 kHz.
|
||||
size_t RefinePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
rtc::ArrayView<const size_t, 2> inv_lags);
|
||||
// Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags.
|
||||
struct CandidatePitchPeriods {
|
||||
int best;
|
||||
int second_best;
|
||||
};
|
||||
|
||||
// Refines the pitch period estimation and compute the pitch gain. Returns the
|
||||
// refined pitch estimation data at 48 kHz.
|
||||
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
// Computes the candidate pitch periods at 12 kHz given a view on the 12 kHz
|
||||
// pitch buffer and the auto-correlation values (having inverted lags as
|
||||
// indexes).
|
||||
CandidatePitchPeriods ComputePitchPeriod12kHz(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer,
|
||||
// the energies for the sliding frames `y` at 24 kHz and the pitch period
|
||||
// candidates at 24 kHz (encoded as inverted lag).
|
||||
int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
CandidatePitchPeriods pitch_candidates_24kHz,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
struct PitchInfo {
|
||||
int period;
|
||||
float strength;
|
||||
};
|
||||
|
||||
// Computes the pitch period at 48 kHz searching in an extended pitch range
|
||||
// given a view on the 24 kHz pitch buffer, the energies for the sliding frames
|
||||
// `y` at 24 kHz, the initial 48 kHz estimation (computed by
|
||||
// `ComputePitchPeriod48kHz()`) and the last estimated pitch.
|
||||
PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
int initial_pitch_period_48kHz,
|
||||
PitchInfo prev_pitch_48kHz);
|
||||
PitchInfo last_pitch_48kHz,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -21,7 +21,7 @@ namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Ring buffer for N arrays of type T each one with size S.
|
||||
template <typename T, size_t S, size_t N>
|
||||
template <typename T, int S, int N>
|
||||
class RingBuffer {
|
||||
static_assert(S > 0, "");
|
||||
static_assert(N > 0, "");
|
||||
@ -35,7 +35,7 @@ class RingBuffer {
|
||||
~RingBuffer() = default;
|
||||
// Set the ring buffer values to zero.
|
||||
void Reset() { buffer_.fill(0); }
|
||||
// Replace the least recently pushed array in the buffer with |new_values|.
|
||||
// Replace the least recently pushed array in the buffer with `new_values`.
|
||||
void Push(rtc::ArrayView<const T, S> new_values) {
|
||||
std::memcpy(buffer_.data() + S * tail_, new_values.data(), S * sizeof(T));
|
||||
tail_ += 1;
|
||||
@ -43,13 +43,12 @@ class RingBuffer {
|
||||
tail_ = 0;
|
||||
}
|
||||
// Return an array view onto the array with a given delay. A view on the last
|
||||
// and least recently push array is returned when |delay| is 0 and N - 1
|
||||
// and least recently push array is returned when `delay` is 0 and N - 1
|
||||
// respectively.
|
||||
rtc::ArrayView<const T, S> GetArrayView(size_t delay) const {
|
||||
const int delay_int = static_cast<int>(delay);
|
||||
RTC_DCHECK_LE(0, delay_int);
|
||||
RTC_DCHECK_LT(delay_int, N);
|
||||
int offset = tail_ - 1 - delay_int;
|
||||
rtc::ArrayView<const T, S> GetArrayView(int delay) const {
|
||||
RTC_DCHECK_LE(0, delay);
|
||||
RTC_DCHECK_LT(delay, N);
|
||||
int offset = tail_ - 1 - delay;
|
||||
if (offset < 0)
|
||||
offset += N;
|
||||
return {buffer_.data() + S * offset, S};
|
||||
|
@ -10,415 +10,81 @@
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||
|
||||
// Defines WEBRTC_ARCH_X86_FAMILY, used below.
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
#include <emmintrin.h>
|
||||
#endif
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
using rnnoise::kWeightsScale;
|
||||
|
||||
using rnnoise::kInputLayerInputSize;
|
||||
using ::rnnoise::kInputLayerInputSize;
|
||||
static_assert(kFeatureVectorSize == kInputLayerInputSize, "");
|
||||
using rnnoise::kInputDenseBias;
|
||||
using rnnoise::kInputDenseWeights;
|
||||
using rnnoise::kInputLayerOutputSize;
|
||||
static_assert(kInputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
|
||||
"Increase kFullyConnectedLayersMaxUnits.");
|
||||
using ::rnnoise::kInputDenseBias;
|
||||
using ::rnnoise::kInputDenseWeights;
|
||||
using ::rnnoise::kInputLayerOutputSize;
|
||||
static_assert(kInputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
|
||||
|
||||
using rnnoise::kHiddenGruBias;
|
||||
using rnnoise::kHiddenGruRecurrentWeights;
|
||||
using rnnoise::kHiddenGruWeights;
|
||||
using rnnoise::kHiddenLayerOutputSize;
|
||||
static_assert(kHiddenLayerOutputSize <= kRecurrentLayersMaxUnits,
|
||||
"Increase kRecurrentLayersMaxUnits.");
|
||||
using ::rnnoise::kHiddenGruBias;
|
||||
using ::rnnoise::kHiddenGruRecurrentWeights;
|
||||
using ::rnnoise::kHiddenGruWeights;
|
||||
using ::rnnoise::kHiddenLayerOutputSize;
|
||||
static_assert(kHiddenLayerOutputSize <= kGruLayerMaxUnits, "");
|
||||
|
||||
using rnnoise::kOutputDenseBias;
|
||||
using rnnoise::kOutputDenseWeights;
|
||||
using rnnoise::kOutputLayerOutputSize;
|
||||
static_assert(kOutputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
|
||||
"Increase kFullyConnectedLayersMaxUnits.");
|
||||
|
||||
using rnnoise::SigmoidApproximated;
|
||||
using rnnoise::TansigApproximated;
|
||||
|
||||
inline float RectifiedLinearUnit(float x) {
|
||||
return x < 0.f ? 0.f : x;
|
||||
}
|
||||
|
||||
std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
|
||||
std::vector<float> scaled_params(params.size());
|
||||
std::transform(params.begin(), params.end(), scaled_params.begin(),
|
||||
[](int8_t x) -> float {
|
||||
return rnnoise::kWeightsScale * static_cast<float>(x);
|
||||
});
|
||||
return scaled_params;
|
||||
}
|
||||
|
||||
// TODO(bugs.chromium.org/10480): Hard-code optimized layout and remove this
|
||||
// function to improve setup time.
|
||||
// Casts and scales |weights| and re-arranges the layout.
|
||||
std::vector<float> GetPreprocessedFcWeights(
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
size_t output_size) {
|
||||
if (output_size == 1) {
|
||||
return GetScaledParams(weights);
|
||||
}
|
||||
// Transpose, scale and cast.
|
||||
const size_t input_size = rtc::CheckedDivExact(weights.size(), output_size);
|
||||
std::vector<float> w(weights.size());
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
w[o * input_size + i] = rnnoise::kWeightsScale *
|
||||
static_cast<float>(weights[i * output_size + o]);
|
||||
}
|
||||
}
|
||||
return w;
|
||||
}
|
||||
|
||||
constexpr size_t kNumGruGates = 3; // Update, reset, output.
|
||||
|
||||
// TODO(bugs.chromium.org/10480): Hard-coded optimized layout and remove this
|
||||
// function to improve setup time.
|
||||
// Casts and scales |tensor_src| for a GRU layer and re-arranges the layout.
|
||||
// It works both for weights, recurrent weights and bias.
|
||||
std::vector<float> GetPreprocessedGruTensor(
|
||||
rtc::ArrayView<const int8_t> tensor_src,
|
||||
size_t output_size) {
|
||||
// Transpose, cast and scale.
|
||||
// |n| is the size of the first dimension of the 3-dim tensor |weights|.
|
||||
const size_t n =
|
||||
rtc::CheckedDivExact(tensor_src.size(), output_size * kNumGruGates);
|
||||
const size_t stride_src = kNumGruGates * output_size;
|
||||
const size_t stride_dst = n * output_size;
|
||||
std::vector<float> tensor_dst(tensor_src.size());
|
||||
for (size_t g = 0; g < kNumGruGates; ++g) {
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
tensor_dst[g * stride_dst + o * n + i] =
|
||||
rnnoise::kWeightsScale *
|
||||
static_cast<float>(
|
||||
tensor_src[i * stride_src + g * output_size + o]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return tensor_dst;
|
||||
}
|
||||
|
||||
void ComputeGruUpdateResetGates(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> state,
|
||||
rtc::ArrayView<float> gate) {
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
gate[o] = bias[o];
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
gate[o] += input[i] * weights[o * input_size + i];
|
||||
}
|
||||
for (size_t s = 0; s < output_size; ++s) {
|
||||
gate[o] += state[s] * recurrent_weights[o * output_size + s];
|
||||
}
|
||||
gate[o] = SigmoidApproximated(gate[o]);
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeGruOutputGate(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> state,
|
||||
rtc::ArrayView<const float> reset,
|
||||
rtc::ArrayView<float> gate) {
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
gate[o] = bias[o];
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
gate[o] += input[i] * weights[o * input_size + i];
|
||||
}
|
||||
for (size_t s = 0; s < output_size; ++s) {
|
||||
gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
|
||||
}
|
||||
gate[o] = RectifiedLinearUnit(gate[o]);
|
||||
}
|
||||
}
|
||||
|
||||
// Gated recurrent unit (GRU) layer un-optimized implementation.
|
||||
void ComputeGruLayerOutput(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<float> state) {
|
||||
RTC_DCHECK_EQ(input_size, input.size());
|
||||
// Stride and offset used to read parameter arrays.
|
||||
const size_t stride_in = input_size * output_size;
|
||||
const size_t stride_out = output_size * output_size;
|
||||
|
||||
// Update gate.
|
||||
std::array<float, kRecurrentLayersMaxUnits> update;
|
||||
ComputeGruUpdateResetGates(
|
||||
input_size, output_size, weights.subview(0, stride_in),
|
||||
recurrent_weights.subview(0, stride_out), bias.subview(0, output_size),
|
||||
input, state, update);
|
||||
|
||||
// Reset gate.
|
||||
std::array<float, kRecurrentLayersMaxUnits> reset;
|
||||
ComputeGruUpdateResetGates(
|
||||
input_size, output_size, weights.subview(stride_in, stride_in),
|
||||
recurrent_weights.subview(stride_out, stride_out),
|
||||
bias.subview(output_size, output_size), input, state, reset);
|
||||
|
||||
// Output gate.
|
||||
std::array<float, kRecurrentLayersMaxUnits> output;
|
||||
ComputeGruOutputGate(
|
||||
input_size, output_size, weights.subview(2 * stride_in, stride_in),
|
||||
recurrent_weights.subview(2 * stride_out, stride_out),
|
||||
bias.subview(2 * output_size, output_size), input, state, reset, output);
|
||||
|
||||
// Update output through the update gates and update the state.
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
output[o] = update[o] * state[o] + (1.f - update[o]) * output[o];
|
||||
state[o] = output[o];
|
||||
}
|
||||
}
|
||||
|
||||
// Fully connected layer un-optimized implementation.
|
||||
void ComputeFullyConnectedLayerOutput(
|
||||
size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
rtc::ArrayView<float> output) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
output[o] = bias[o];
|
||||
// TODO(bugs.chromium.org/9076): Benchmark how different layouts for
|
||||
// |weights_| change the performance across different platforms.
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
output[o] += input[i] * weights[o * input_size + i];
|
||||
}
|
||||
output[o] = activation_function(output[o]);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
// Fully connected layer SSE2 implementation.
|
||||
void ComputeFullyConnectedLayerOutputSse2(
|
||||
size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
rtc::ArrayView<float> output) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
const size_t input_size_by_4 = input_size >> 2;
|
||||
const size_t offset = input_size & ~3;
|
||||
__m128 sum_wx_128;
|
||||
const float* v = reinterpret_cast<const float*>(&sum_wx_128);
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
// Perform 128 bit vector operations.
|
||||
sum_wx_128 = _mm_set1_ps(0);
|
||||
const float* x_p = input.data();
|
||||
const float* w_p = weights.data() + o * input_size;
|
||||
for (size_t i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) {
|
||||
sum_wx_128 = _mm_add_ps(sum_wx_128,
|
||||
_mm_mul_ps(_mm_loadu_ps(x_p), _mm_loadu_ps(w_p)));
|
||||
}
|
||||
// Perform non-vector operations for any remaining items, sum up bias term
|
||||
// and results from the vectorized code, and apply the activation function.
|
||||
output[o] = activation_function(
|
||||
std::inner_product(input.begin() + offset, input.end(),
|
||||
weights.begin() + o * input_size + offset,
|
||||
bias[o] + v[0] + v[1] + v[2] + v[3]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
using ::rnnoise::kOutputDenseBias;
|
||||
using ::rnnoise::kOutputDenseWeights;
|
||||
using ::rnnoise::kOutputLayerOutputSize;
|
||||
static_assert(kOutputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
|
||||
|
||||
} // namespace
|
||||
|
||||
FullyConnectedLayer::FullyConnectedLayer(
|
||||
const size_t input_size,
|
||||
const size_t output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
Optimization optimization)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(GetScaledParams(bias)),
|
||||
weights_(GetPreprocessedFcWeights(weights, output_size)),
|
||||
activation_function_(activation_function),
|
||||
optimization_(optimization) {
|
||||
RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
|
||||
<< "Static over-allocation of fully-connected layers output vectors is "
|
||||
"not sufficient.";
|
||||
RTC_DCHECK_EQ(output_size_, bias_.size())
|
||||
<< "Mismatching output size and bias terms array size.";
|
||||
RTC_DCHECK_EQ(input_size_ * output_size_, weights_.size())
|
||||
<< "Mismatching input-output size and weight coefficients array size.";
|
||||
}
|
||||
|
||||
FullyConnectedLayer::~FullyConnectedLayer() = default;
|
||||
|
||||
rtc::ArrayView<const float> FullyConnectedLayer::GetOutput() const {
|
||||
return rtc::ArrayView<const float>(output_.data(), output_size_);
|
||||
}
|
||||
|
||||
void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
switch (optimization_) {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
case Optimization::kSse2:
|
||||
ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input,
|
||||
bias_, weights_,
|
||||
activation_function_, output_);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Optimization::kNeon:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
|
||||
ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
|
||||
weights_, activation_function_, output_);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
|
||||
weights_, activation_function_, output_);
|
||||
}
|
||||
}
|
||||
|
||||
GatedRecurrentLayer::GatedRecurrentLayer(
|
||||
const size_t input_size,
|
||||
const size_t output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
const rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
Optimization optimization)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(GetPreprocessedGruTensor(bias, output_size)),
|
||||
weights_(GetPreprocessedGruTensor(weights, output_size)),
|
||||
recurrent_weights_(
|
||||
GetPreprocessedGruTensor(recurrent_weights, output_size)),
|
||||
optimization_(optimization) {
|
||||
RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
|
||||
<< "Static over-allocation of recurrent layers state vectors is not "
|
||||
"sufficient.";
|
||||
RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
|
||||
<< "Mismatching output size and bias terms array size.";
|
||||
RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
|
||||
<< "Mismatching input-output size and weight coefficients array size.";
|
||||
RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
|
||||
recurrent_weights_.size())
|
||||
<< "Mismatching input-output size and recurrent weight coefficients array"
|
||||
" size.";
|
||||
Reset();
|
||||
}
|
||||
|
||||
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
|
||||
|
||||
rtc::ArrayView<const float> GatedRecurrentLayer::GetOutput() const {
|
||||
return rtc::ArrayView<const float>(state_.data(), output_size_);
|
||||
}
|
||||
|
||||
void GatedRecurrentLayer::Reset() {
|
||||
state_.fill(0.f);
|
||||
}
|
||||
|
||||
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
switch (optimization_) {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
case Optimization::kSse2:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
|
||||
ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
|
||||
recurrent_weights_, bias_, state_);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Optimization::kNeon:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
|
||||
ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
|
||||
recurrent_weights_, bias_, state_);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
|
||||
recurrent_weights_, bias_, state_);
|
||||
}
|
||||
}
|
||||
|
||||
RnnBasedVad::RnnBasedVad()
|
||||
: input_layer_(kInputLayerInputSize,
|
||||
kInputLayerOutputSize,
|
||||
kInputDenseBias,
|
||||
kInputDenseWeights,
|
||||
TansigApproximated,
|
||||
DetectOptimization()),
|
||||
hidden_layer_(kInputLayerOutputSize,
|
||||
kHiddenLayerOutputSize,
|
||||
kHiddenGruBias,
|
||||
kHiddenGruWeights,
|
||||
kHiddenGruRecurrentWeights,
|
||||
DetectOptimization()),
|
||||
output_layer_(kHiddenLayerOutputSize,
|
||||
kOutputLayerOutputSize,
|
||||
kOutputDenseBias,
|
||||
kOutputDenseWeights,
|
||||
SigmoidApproximated,
|
||||
DetectOptimization()) {
|
||||
RnnVad::RnnVad(const AvailableCpuFeatures& cpu_features)
|
||||
: input_(kInputLayerInputSize,
|
||||
kInputLayerOutputSize,
|
||||
kInputDenseBias,
|
||||
kInputDenseWeights,
|
||||
ActivationFunction::kTansigApproximated,
|
||||
cpu_features,
|
||||
/*layer_name=*/"FC1"),
|
||||
hidden_(kInputLayerOutputSize,
|
||||
kHiddenLayerOutputSize,
|
||||
kHiddenGruBias,
|
||||
kHiddenGruWeights,
|
||||
kHiddenGruRecurrentWeights,
|
||||
cpu_features,
|
||||
/*layer_name=*/"GRU1"),
|
||||
output_(kHiddenLayerOutputSize,
|
||||
kOutputLayerOutputSize,
|
||||
kOutputDenseBias,
|
||||
kOutputDenseWeights,
|
||||
ActivationFunction::kSigmoidApproximated,
|
||||
// The output layer is just 24x1. The unoptimized code is faster.
|
||||
NoAvailableCpuFeatures(),
|
||||
/*layer_name=*/"FC2") {
|
||||
// Input-output chaining size checks.
|
||||
RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size())
|
||||
RTC_DCHECK_EQ(input_.size(), hidden_.input_size())
|
||||
<< "The input and the hidden layers sizes do not match.";
|
||||
RTC_DCHECK_EQ(hidden_layer_.output_size(), output_layer_.input_size())
|
||||
RTC_DCHECK_EQ(hidden_.size(), output_.input_size())
|
||||
<< "The hidden and the output layers sizes do not match.";
|
||||
}
|
||||
|
||||
RnnBasedVad::~RnnBasedVad() = default;
|
||||
RnnVad::~RnnVad() = default;
|
||||
|
||||
void RnnBasedVad::Reset() {
|
||||
hidden_layer_.Reset();
|
||||
void RnnVad::Reset() {
|
||||
hidden_.Reset();
|
||||
}
|
||||
|
||||
float RnnBasedVad::ComputeVadProbability(
|
||||
float RnnVad::ComputeVadProbability(
|
||||
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
|
||||
bool is_silence) {
|
||||
if (is_silence) {
|
||||
Reset();
|
||||
return 0.f;
|
||||
}
|
||||
input_layer_.ComputeOutput(feature_vector);
|
||||
hidden_layer_.ComputeOutput(input_layer_.GetOutput());
|
||||
output_layer_.ComputeOutput(hidden_layer_.GetOutput());
|
||||
const auto vad_output = output_layer_.GetOutput();
|
||||
return vad_output[0];
|
||||
input_.ComputeOutput(feature_vector);
|
||||
hidden_.ComputeOutput(input_);
|
||||
output_.ComputeOutput(hidden_);
|
||||
RTC_DCHECK_EQ(output_.size(), 1);
|
||||
return output_.data()[0];
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
|
@ -18,106 +18,33 @@
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "api/function_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Maximum number of units for a fully-connected layer. This value is used to
|
||||
// over-allocate space for fully-connected layers output vectors (implemented as
|
||||
// std::array). The value should equal the number of units of the largest
|
||||
// fully-connected layer.
|
||||
constexpr size_t kFullyConnectedLayersMaxUnits = 24;
|
||||
|
||||
// Maximum number of units for a recurrent layer. This value is used to
|
||||
// over-allocate space for recurrent layers state vectors (implemented as
|
||||
// std::array). The value should equal the number of units of the largest
|
||||
// recurrent layer.
|
||||
constexpr size_t kRecurrentLayersMaxUnits = 24;
|
||||
|
||||
// Fully-connected layer.
|
||||
class FullyConnectedLayer {
|
||||
// Recurrent network with hard-coded architecture and weights for voice activity
|
||||
// detection.
|
||||
class RnnVad {
|
||||
public:
|
||||
FullyConnectedLayer(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
Optimization optimization);
|
||||
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
|
||||
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
|
||||
~FullyConnectedLayer();
|
||||
size_t input_size() const { return input_size_; }
|
||||
size_t output_size() const { return output_size_; }
|
||||
Optimization optimization() const { return optimization_; }
|
||||
rtc::ArrayView<const float> GetOutput() const;
|
||||
// Computes the fully-connected layer output.
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
const size_t input_size_;
|
||||
const size_t output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
rtc::FunctionView<float(float)> activation_function_;
|
||||
// The output vector of a recurrent layer has length equal to |output_size_|.
|
||||
// However, for efficiency, over-allocation is used.
|
||||
std::array<float, kFullyConnectedLayersMaxUnits> output_;
|
||||
const Optimization optimization_;
|
||||
};
|
||||
|
||||
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
|
||||
// activation functions for the update/reset and output gates respectively.
|
||||
class GatedRecurrentLayer {
|
||||
public:
|
||||
GatedRecurrentLayer(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
Optimization optimization);
|
||||
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
|
||||
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
|
||||
~GatedRecurrentLayer();
|
||||
size_t input_size() const { return input_size_; }
|
||||
size_t output_size() const { return output_size_; }
|
||||
Optimization optimization() const { return optimization_; }
|
||||
rtc::ArrayView<const float> GetOutput() const;
|
||||
explicit RnnVad(const AvailableCpuFeatures& cpu_features);
|
||||
RnnVad(const RnnVad&) = delete;
|
||||
RnnVad& operator=(const RnnVad&) = delete;
|
||||
~RnnVad();
|
||||
void Reset();
|
||||
// Computes the recurrent layer output and updates the status.
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
const size_t input_size_;
|
||||
const size_t output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
const std::vector<float> recurrent_weights_;
|
||||
// The state vector of a recurrent layer has length equal to |output_size_|.
|
||||
// However, to avoid dynamic allocation, over-allocation is used.
|
||||
std::array<float, kRecurrentLayersMaxUnits> state_;
|
||||
const Optimization optimization_;
|
||||
};
|
||||
|
||||
// Recurrent network based VAD.
|
||||
class RnnBasedVad {
|
||||
public:
|
||||
RnnBasedVad();
|
||||
RnnBasedVad(const RnnBasedVad&) = delete;
|
||||
RnnBasedVad& operator=(const RnnBasedVad&) = delete;
|
||||
~RnnBasedVad();
|
||||
void Reset();
|
||||
// Compute and returns the probability of voice (range: [0.0, 1.0]).
|
||||
// Observes `feature_vector` and `is_silence`, updates the RNN and returns the
|
||||
// current voice probability.
|
||||
float ComputeVadProbability(
|
||||
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
|
||||
bool is_silence);
|
||||
|
||||
private:
|
||||
FullyConnectedLayer input_layer_;
|
||||
GatedRecurrentLayer hidden_layer_;
|
||||
FullyConnectedLayer output_layer_;
|
||||
FullyConnectedLayer input_;
|
||||
GatedRecurrentLayer hidden_;
|
||||
FullyConnectedLayer output_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
|
104
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_fc.cc
Normal file
104
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_fc.cc
Normal file
@ -0,0 +1,104 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
|
||||
std::vector<float> scaled_params(params.size());
|
||||
std::transform(params.begin(), params.end(), scaled_params.begin(),
|
||||
[](int8_t x) -> float {
|
||||
return ::rnnoise::kWeightsScale * static_cast<float>(x);
|
||||
});
|
||||
return scaled_params;
|
||||
}
|
||||
|
||||
// TODO(bugs.chromium.org/10480): Hard-code optimized layout and remove this
|
||||
// function to improve setup time.
|
||||
// Casts and scales `weights` and re-arranges the layout.
|
||||
std::vector<float> PreprocessWeights(rtc::ArrayView<const int8_t> weights,
|
||||
int output_size) {
|
||||
if (output_size == 1) {
|
||||
return GetScaledParams(weights);
|
||||
}
|
||||
// Transpose, scale and cast.
|
||||
const int input_size = rtc::CheckedDivExact(
|
||||
rtc::dchecked_cast<int>(weights.size()), output_size);
|
||||
std::vector<float> w(weights.size());
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
for (int i = 0; i < input_size; ++i) {
|
||||
w[o * input_size + i] = rnnoise::kWeightsScale *
|
||||
static_cast<float>(weights[i * output_size + o]);
|
||||
}
|
||||
}
|
||||
return w;
|
||||
}
|
||||
|
||||
rtc::FunctionView<float(float)> GetActivationFunction(
|
||||
ActivationFunction activation_function) {
|
||||
switch (activation_function) {
|
||||
case ActivationFunction::kTansigApproximated:
|
||||
return ::rnnoise::TansigApproximated;
|
||||
case ActivationFunction::kSigmoidApproximated:
|
||||
return ::rnnoise::SigmoidApproximated;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
FullyConnectedLayer::FullyConnectedLayer(
|
||||
const int input_size,
|
||||
const int output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
ActivationFunction activation_function,
|
||||
const AvailableCpuFeatures& cpu_features,
|
||||
absl::string_view layer_name)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(GetScaledParams(bias)),
|
||||
weights_(PreprocessWeights(weights, output_size)),
|
||||
vector_math_(cpu_features),
|
||||
activation_function_(GetActivationFunction(activation_function)) {
|
||||
RTC_DCHECK_LE(output_size_, kFullyConnectedLayerMaxUnits)
|
||||
<< "Insufficient FC layer over-allocation (" << layer_name << ").";
|
||||
RTC_DCHECK_EQ(output_size_, bias_.size())
|
||||
<< "Mismatching output size and bias terms array size (" << layer_name
|
||||
<< ").";
|
||||
RTC_DCHECK_EQ(input_size_ * output_size_, weights_.size())
|
||||
<< "Mismatching input-output size and weight coefficients array size ("
|
||||
<< layer_name << ").";
|
||||
}
|
||||
|
||||
FullyConnectedLayer::~FullyConnectedLayer() = default;
|
||||
|
||||
void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size_);
|
||||
rtc::ArrayView<const float> weights(weights_);
|
||||
for (int o = 0; o < output_size_; ++o) {
|
||||
output_[o] = activation_function_(
|
||||
bias_[o] + vector_math_.DotProduct(
|
||||
input, weights.subview(o * input_size_, input_size_)));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
72
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_fc.h
Normal file
72
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_fc.h
Normal file
@ -0,0 +1,72 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "api/array_view.h"
|
||||
#include "api/function_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Activation function for a neural network cell.
|
||||
enum class ActivationFunction { kTansigApproximated, kSigmoidApproximated };
|
||||
|
||||
// Maximum number of units for an FC layer.
|
||||
constexpr int kFullyConnectedLayerMaxUnits = 24;
|
||||
|
||||
// Fully-connected layer with a custom activation function which owns the output
|
||||
// buffer.
|
||||
class FullyConnectedLayer {
|
||||
public:
|
||||
// Ctor. `output_size` cannot be greater than `kFullyConnectedLayerMaxUnits`.
|
||||
FullyConnectedLayer(int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
ActivationFunction activation_function,
|
||||
const AvailableCpuFeatures& cpu_features,
|
||||
absl::string_view layer_name);
|
||||
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
|
||||
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
|
||||
~FullyConnectedLayer();
|
||||
|
||||
// Returns the size of the input vector.
|
||||
int input_size() const { return input_size_; }
|
||||
// Returns the pointer to the first element of the output buffer.
|
||||
const float* data() const { return output_.data(); }
|
||||
// Returns the size of the output buffer.
|
||||
int size() const { return output_size_; }
|
||||
|
||||
// Computes the fully-connected layer output.
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
const int input_size_;
|
||||
const int output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
const VectorMath vector_math_;
|
||||
rtc::FunctionView<float(float)> activation_function_;
|
||||
// Over-allocated array with size equal to `output_size_`.
|
||||
std::array<float, kFullyConnectedLayerMaxUnits> output_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
|
198
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
Normal file
198
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
Normal file
@ -0,0 +1,198 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
constexpr int kNumGruGates = 3; // Update, reset, output.
|
||||
|
||||
std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
|
||||
int output_size) {
|
||||
// Transpose, cast and scale.
|
||||
// `n` is the size of the first dimension of the 3-dim tensor `weights`.
|
||||
const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
|
||||
output_size * kNumGruGates);
|
||||
const int stride_src = kNumGruGates * output_size;
|
||||
const int stride_dst = n * output_size;
|
||||
std::vector<float> tensor_dst(tensor_src.size());
|
||||
for (int g = 0; g < kNumGruGates; ++g) {
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
tensor_dst[g * stride_dst + o * n + i] =
|
||||
::rnnoise::kWeightsScale *
|
||||
static_cast<float>(
|
||||
tensor_src[i * stride_src + g * output_size + o]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return tensor_dst;
|
||||
}
|
||||
|
||||
// Computes the output for the update or the reset gate.
|
||||
// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
|
||||
// - `g`: output gate vector
|
||||
// - `W`: weights matrix
|
||||
// - `i`: input vector
|
||||
// - `R`: recurrent weights matrix
|
||||
// - `s`: state gate vector
|
||||
// - `b`: bias vector
|
||||
void ComputeUpdateResetGate(int input_size,
|
||||
int output_size,
|
||||
const VectorMath& vector_math,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> state,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<float> gate) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_EQ(state.size(), output_size);
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
|
||||
RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated.
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
float x = bias[o];
|
||||
x += vector_math.DotProduct(input,
|
||||
weights.subview(o * input_size, input_size));
|
||||
x += vector_math.DotProduct(
|
||||
state, recurrent_weights.subview(o * output_size, output_size));
|
||||
gate[o] = ::rnnoise::SigmoidApproximated(x);
|
||||
}
|
||||
}
|
||||
|
||||
// Computes the output for the state gate.
|
||||
// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
|
||||
// - `s'`: output state gate vector
|
||||
// - `s`: previous state gate vector
|
||||
// - `u`: update gate vector
|
||||
// - `W`: weights matrix
|
||||
// - `i`: input vector
|
||||
// - `R`: recurrent weights matrix
|
||||
// - `r`: reset gate vector
|
||||
// - `b`: bias vector
|
||||
// - `.*` element-wise product
|
||||
void ComputeStateGate(int input_size,
|
||||
int output_size,
|
||||
const VectorMath& vector_math,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> update,
|
||||
rtc::ArrayView<const float> reset,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<float> state) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated.
|
||||
RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated.
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
|
||||
RTC_DCHECK_EQ(state.size(), output_size);
|
||||
std::array<float, kGruLayerMaxUnits> reset_x_state;
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
reset_x_state[o] = state[o] * reset[o];
|
||||
}
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
float x = bias[o];
|
||||
x += vector_math.DotProduct(input,
|
||||
weights.subview(o * input_size, input_size));
|
||||
x += vector_math.DotProduct(
|
||||
{reset_x_state.data(), static_cast<size_t>(output_size)},
|
||||
recurrent_weights.subview(o * output_size, output_size));
|
||||
state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GatedRecurrentLayer::GatedRecurrentLayer(
|
||||
const int input_size,
|
||||
const int output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
const rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
const AvailableCpuFeatures& cpu_features,
|
||||
absl::string_view layer_name)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(PreprocessGruTensor(bias, output_size)),
|
||||
weights_(PreprocessGruTensor(weights, output_size)),
|
||||
recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
|
||||
vector_math_(cpu_features) {
|
||||
RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
|
||||
<< "Insufficient GRU layer over-allocation (" << layer_name << ").";
|
||||
RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
|
||||
<< "Mismatching output size and bias terms array size (" << layer_name
|
||||
<< ").";
|
||||
RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
|
||||
<< "Mismatching input-output size and weight coefficients array size ("
|
||||
<< layer_name << ").";
|
||||
RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
|
||||
recurrent_weights_.size())
|
||||
<< "Mismatching input-output size and recurrent weight coefficients array"
|
||||
" size ("
|
||||
<< layer_name << ").";
|
||||
Reset();
|
||||
}
|
||||
|
||||
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
|
||||
|
||||
void GatedRecurrentLayer::Reset() {
|
||||
state_.fill(0.f);
|
||||
}
|
||||
|
||||
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size_);
|
||||
|
||||
// The tensors below are organized as a sequence of flattened tensors for the
|
||||
// `update`, `reset` and `state` gates.
|
||||
rtc::ArrayView<const float> bias(bias_);
|
||||
rtc::ArrayView<const float> weights(weights_);
|
||||
rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
|
||||
// Strides to access to the flattened tensors for a specific gate.
|
||||
const int stride_weights = input_size_ * output_size_;
|
||||
const int stride_recurrent_weights = output_size_ * output_size_;
|
||||
|
||||
rtc::ArrayView<float> state(state_.data(), output_size_);
|
||||
|
||||
// Update gate.
|
||||
std::array<float, kGruLayerMaxUnits> update;
|
||||
ComputeUpdateResetGate(
|
||||
input_size_, output_size_, vector_math_, input, state,
|
||||
bias.subview(0, output_size_), weights.subview(0, stride_weights),
|
||||
recurrent_weights.subview(0, stride_recurrent_weights), update);
|
||||
// Reset gate.
|
||||
std::array<float, kGruLayerMaxUnits> reset;
|
||||
ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
|
||||
bias.subview(output_size_, output_size_),
|
||||
weights.subview(stride_weights, stride_weights),
|
||||
recurrent_weights.subview(stride_recurrent_weights,
|
||||
stride_recurrent_weights),
|
||||
reset);
|
||||
// State gate.
|
||||
ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
|
||||
reset, bias.subview(2 * output_size_, output_size_),
|
||||
weights.subview(2 * stride_weights, stride_weights),
|
||||
recurrent_weights.subview(2 * stride_recurrent_weights,
|
||||
stride_recurrent_weights),
|
||||
state);
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
70
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
Normal file
70
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
Normal file
@ -0,0 +1,70 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Maximum number of units for a GRU layer.
|
||||
constexpr int kGruLayerMaxUnits = 24;
|
||||
|
||||
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
|
||||
// activation functions for the update/reset and output gates respectively.
|
||||
class GatedRecurrentLayer {
|
||||
public:
|
||||
// Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
|
||||
GatedRecurrentLayer(int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
const AvailableCpuFeatures& cpu_features,
|
||||
absl::string_view layer_name);
|
||||
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
|
||||
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
|
||||
~GatedRecurrentLayer();
|
||||
|
||||
// Returns the size of the input vector.
|
||||
int input_size() const { return input_size_; }
|
||||
// Returns the pointer to the first element of the output buffer.
|
||||
const float* data() const { return state_.data(); }
|
||||
// Returns the size of the output buffer.
|
||||
int size() const { return output_size_; }
|
||||
|
||||
// Resets the GRU state.
|
||||
void Reset();
|
||||
// Computes the recurrent layer output and updates the status.
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
const int input_size_;
|
||||
const int output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
const std::vector<float> recurrent_weights_;
|
||||
const VectorMath vector_math_;
|
||||
// Over-allocated array with size equal to `output_size_`.
|
||||
std::array<float, kGruLayerMaxUnits> state_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
|
@ -29,7 +29,7 @@ namespace rnn_vad {
|
||||
// values are written at the end of the buffer.
|
||||
// The class also provides a view on the most recent M values, where 0 < M <= S
|
||||
// and by default M = N.
|
||||
template <typename T, size_t S, size_t N, size_t M = N>
|
||||
template <typename T, int S, int N, int M = N>
|
||||
class SequenceBuffer {
|
||||
static_assert(N <= S,
|
||||
"The new chunk size cannot be larger than the sequence buffer "
|
||||
@ -45,8 +45,8 @@ class SequenceBuffer {
|
||||
SequenceBuffer(const SequenceBuffer&) = delete;
|
||||
SequenceBuffer& operator=(const SequenceBuffer&) = delete;
|
||||
~SequenceBuffer() = default;
|
||||
size_t size() const { return S; }
|
||||
size_t chunks_size() const { return N; }
|
||||
int size() const { return S; }
|
||||
int chunks_size() const { return N; }
|
||||
// Sets the sequence buffer values to zero.
|
||||
void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); }
|
||||
// Returns a view on the whole buffer.
|
||||
|
@ -16,6 +16,7 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
@ -32,11 +33,11 @@ void UpdateCepstralDifferenceStats(
|
||||
RTC_DCHECK(sym_matrix_buf);
|
||||
// Compute the new cepstral distance stats.
|
||||
std::array<float, kCepstralCoeffsHistorySize - 1> distances;
|
||||
for (size_t i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
|
||||
const size_t delay = i + 1;
|
||||
for (int i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
|
||||
const int delay = i + 1;
|
||||
auto old_cepstral_coeffs = ring_buf.GetArrayView(delay);
|
||||
distances[i] = 0.f;
|
||||
for (size_t k = 0; k < kNumBands; ++k) {
|
||||
for (int k = 0; k < kNumBands; ++k) {
|
||||
const float c = new_cepstral_coeffs[k] - old_cepstral_coeffs[k];
|
||||
distances[i] += c * c;
|
||||
}
|
||||
@ -48,9 +49,9 @@ void UpdateCepstralDifferenceStats(
|
||||
// Computes the first half of the Vorbis window.
|
||||
std::array<float, kFrameSize20ms24kHz / 2> ComputeScaledHalfVorbisWindow(
|
||||
float scaling = 1.f) {
|
||||
constexpr size_t kHalfSize = kFrameSize20ms24kHz / 2;
|
||||
constexpr int kHalfSize = kFrameSize20ms24kHz / 2;
|
||||
std::array<float, kHalfSize> half_window{};
|
||||
for (size_t i = 0; i < kHalfSize; ++i) {
|
||||
for (int i = 0; i < kHalfSize; ++i) {
|
||||
half_window[i] =
|
||||
scaling *
|
||||
std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfSize) *
|
||||
@ -71,8 +72,8 @@ void ComputeWindowedForwardFft(
|
||||
RTC_DCHECK_EQ(frame.size(), 2 * half_window.size());
|
||||
// Apply windowing.
|
||||
auto in = fft_input_buffer->GetView();
|
||||
for (size_t i = 0, j = kFrameSize20ms24kHz - 1; i < half_window.size();
|
||||
++i, --j) {
|
||||
for (int i = 0, j = kFrameSize20ms24kHz - 1;
|
||||
rtc::SafeLt(i, half_window.size()); ++i, --j) {
|
||||
in[i] = frame[i] * half_window[i];
|
||||
in[j] = frame[j] * half_window[i];
|
||||
}
|
||||
@ -162,7 +163,7 @@ void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
|
||||
RTC_DCHECK_EQ(average.size(), first_derivative.size());
|
||||
RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size());
|
||||
RTC_DCHECK_LE(average.size(), curr.size());
|
||||
for (size_t i = 0; i < average.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, average.size()); ++i) {
|
||||
// Average, kernel: [1, 1, 1].
|
||||
average[i] = curr[i] + prev1[i] + prev2[i];
|
||||
// First derivative, kernel: [1, 0, - 1].
|
||||
@ -178,7 +179,7 @@ void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
|
||||
reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(),
|
||||
bands_cross_corr_);
|
||||
// Normalize.
|
||||
for (size_t i = 0; i < bands_cross_corr_.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, bands_cross_corr_.size()); ++i) {
|
||||
bands_cross_corr_[i] =
|
||||
bands_cross_corr_[i] /
|
||||
std::sqrt(0.001f + reference_frame_bands_energy_[i] *
|
||||
@ -194,9 +195,9 @@ void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
|
||||
float SpectralFeaturesExtractor::ComputeVariability() const {
|
||||
// Compute cepstral variability score.
|
||||
float variability = 0.f;
|
||||
for (size_t delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
|
||||
for (int delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
|
||||
float min_dist = std::numeric_limits<float>::max();
|
||||
for (size_t delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
|
||||
for (int delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
|
||||
if (delay1 == delay2) // The distance would be 0.
|
||||
continue;
|
||||
min_dist =
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include <cstddef>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
@ -22,7 +23,7 @@ namespace {
|
||||
|
||||
// Weights for each FFT coefficient for each Opus band (Nyquist frequency
|
||||
// excluded). The size of each band is specified in
|
||||
// |kOpusScaleNumBins24kHz20ms|.
|
||||
// `kOpusScaleNumBins24kHz20ms`.
|
||||
constexpr std::array<float, kFrameSize20ms24kHz / 2> kOpusBandWeights24kHz20ms =
|
||||
{{
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 0
|
||||
@ -105,9 +106,9 @@ void SpectralCorrelator::ComputeCrossCorrelation(
|
||||
RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
|
||||
RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
|
||||
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
|
||||
size_t k = 0; // Next Fourier coefficient index.
|
||||
int k = 0; // Next Fourier coefficient index.
|
||||
cross_corr[0] = 0.f;
|
||||
for (size_t i = 0; i < kOpusBands24kHz - 1; ++i) {
|
||||
for (int i = 0; i < kOpusBands24kHz - 1; ++i) {
|
||||
cross_corr[i + 1] = 0.f;
|
||||
for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) { // Band size.
|
||||
const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1];
|
||||
@ -137,11 +138,11 @@ void ComputeSmoothedLogMagnitudeSpectrum(
|
||||
return x;
|
||||
};
|
||||
// Smoothing over the bands for which the band energy is defined.
|
||||
for (size_t i = 0; i < bands_energy.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, bands_energy.size()); ++i) {
|
||||
log_bands_energy[i] = smooth(std::log10(kOneByHundred + bands_energy[i]));
|
||||
}
|
||||
// Smoothing over the remaining bands (zero energy).
|
||||
for (size_t i = bands_energy.size(); i < kNumBands; ++i) {
|
||||
for (int i = bands_energy.size(); i < kNumBands; ++i) {
|
||||
log_bands_energy[i] = smooth(kLogOneByHundred);
|
||||
}
|
||||
}
|
||||
@ -149,8 +150,8 @@ void ComputeSmoothedLogMagnitudeSpectrum(
|
||||
std::array<float, kNumBands * kNumBands> ComputeDctTable() {
|
||||
std::array<float, kNumBands * kNumBands> dct_table;
|
||||
const double k = std::sqrt(0.5);
|
||||
for (size_t i = 0; i < kNumBands; ++i) {
|
||||
for (size_t j = 0; j < kNumBands; ++j)
|
||||
for (int i = 0; i < kNumBands; ++i) {
|
||||
for (int j = 0; j < kNumBands; ++j)
|
||||
dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
|
||||
dct_table[i * kNumBands] *= k;
|
||||
}
|
||||
@ -173,9 +174,9 @@ void ComputeDct(rtc::ArrayView<const float> in,
|
||||
RTC_DCHECK_LE(in.size(), kNumBands);
|
||||
RTC_DCHECK_LE(1, out.size());
|
||||
RTC_DCHECK_LE(out.size(), in.size());
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, out.size()); ++i) {
|
||||
out[i] = 0.f;
|
||||
for (size_t j = 0; j < in.size(); ++j) {
|
||||
for (int j = 0; rtc::SafeLt(j, in.size()); ++j) {
|
||||
out[i] += in[j] * dct_table[j * kNumBands + i];
|
||||
}
|
||||
// TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table.
|
||||
|
@ -25,7 +25,7 @@ namespace rnn_vad {
|
||||
// At a sample rate of 24 kHz, the last 3 Opus bands are beyond the Nyquist
|
||||
// frequency. However, band #19 gets the contributions from band #18 because
|
||||
// of the symmetric triangular filter with peak response at 12 kHz.
|
||||
constexpr size_t kOpusBands24kHz = 20;
|
||||
constexpr int kOpusBands24kHz = 20;
|
||||
static_assert(kOpusBands24kHz < kNumBands,
|
||||
"The number of bands at 24 kHz must be less than those defined "
|
||||
"in the Opus scale at 48 kHz.");
|
||||
@ -50,8 +50,8 @@ class SpectralCorrelator {
|
||||
~SpectralCorrelator();
|
||||
|
||||
// Computes the band-wise spectral auto-correlations.
|
||||
// |x| must:
|
||||
// - have size equal to |kFrameSize20ms24kHz|;
|
||||
// `x` must:
|
||||
// - have size equal to `kFrameSize20ms24kHz`;
|
||||
// - be encoded as vectors of interleaved real-complex FFT coefficients
|
||||
// where x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
|
||||
void ComputeAutoCorrelation(
|
||||
@ -59,8 +59,8 @@ class SpectralCorrelator {
|
||||
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const;
|
||||
|
||||
// Computes the band-wise spectral cross-correlations.
|
||||
// |x| and |y| must:
|
||||
// - have size equal to |kFrameSize20ms24kHz|;
|
||||
// `x` and `y` must:
|
||||
// - have size equal to `kFrameSize20ms24kHz`;
|
||||
// - be encoded as vectors of interleaved real-complex FFT coefficients where
|
||||
// x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
|
||||
void ComputeCrossCorrelation(
|
||||
@ -82,12 +82,12 @@ void ComputeSmoothedLogMagnitudeSpectrum(
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
|
||||
// spectral_features.cc. Creates a DCT table for arrays having size equal to
|
||||
// |kNumBands|. Declared here for unit testing.
|
||||
// `kNumBands`. Declared here for unit testing.
|
||||
std::array<float, kNumBands * kNumBands> ComputeDctTable();
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
|
||||
// spectral_features.cc. Computes DCT for |in| given a pre-computed DCT table.
|
||||
// In-place computation is not allowed and |out| can be smaller than |in| in
|
||||
// spectral_features.cc. Computes DCT for `in` given a pre-computed DCT table.
|
||||
// In-place computation is not allowed and `out` can be smaller than `in` in
|
||||
// order to only compute the first DCT coefficients. Declared here for unit
|
||||
// testing.
|
||||
void ComputeDct(rtc::ArrayView<const float> in,
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
@ -29,7 +30,7 @@ namespace rnn_vad {
|
||||
// removed when one of the two corresponding items that have been compared is
|
||||
// removed from the ring buffer. It is assumed that the comparison is symmetric
|
||||
// and that comparing an item with itself is not needed.
|
||||
template <typename T, size_t S>
|
||||
template <typename T, int S>
|
||||
class SymmetricMatrixBuffer {
|
||||
static_assert(S > 2, "");
|
||||
|
||||
@ -45,9 +46,9 @@ class SymmetricMatrixBuffer {
|
||||
buf_.fill(0);
|
||||
}
|
||||
// Pushes the results from the comparison between the most recent item and
|
||||
// those that are still in the ring buffer. The first element in |values| must
|
||||
// those that are still in the ring buffer. The first element in `values` must
|
||||
// correspond to the comparison between the most recent item and the second
|
||||
// most recent one in the ring buffer, whereas the last element in |values|
|
||||
// most recent one in the ring buffer, whereas the last element in `values`
|
||||
// must correspond to the comparison between the most recent item and the
|
||||
// oldest one in the ring buffer.
|
||||
void Push(rtc::ArrayView<T, S - 1> values) {
|
||||
@ -55,19 +56,19 @@ class SymmetricMatrixBuffer {
|
||||
// column left.
|
||||
std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T));
|
||||
// Copy new values in the last column in the right order.
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
const size_t index = (S - 1 - i) * (S - 1) - 1;
|
||||
RTC_DCHECK_LE(static_cast<size_t>(0), index);
|
||||
for (int i = 0; rtc::SafeLt(i, values.size()); ++i) {
|
||||
const int index = (S - 1 - i) * (S - 1) - 1;
|
||||
RTC_DCHECK_GE(index, 0);
|
||||
RTC_DCHECK_LT(index, buf_.size());
|
||||
buf_[index] = values[i];
|
||||
}
|
||||
}
|
||||
// Reads the value that corresponds to comparison of two items in the ring
|
||||
// buffer having delay |delay1| and |delay2|. The two arguments must not be
|
||||
// buffer having delay `delay1` and `delay2`. The two arguments must not be
|
||||
// equal and both must be in {0, ..., S - 1}.
|
||||
T GetValue(size_t delay1, size_t delay2) const {
|
||||
int row = S - 1 - static_cast<int>(delay1);
|
||||
int col = S - 1 - static_cast<int>(delay2);
|
||||
T GetValue(int delay1, int delay2) const {
|
||||
int row = S - 1 - delay1;
|
||||
int col = S - 1 - delay2;
|
||||
RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed.";
|
||||
if (row > col)
|
||||
std::swap(row, col); // Swap to access the upper-right triangular part.
|
||||
|
@ -10,21 +10,61 @@
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "test/gtest.h"
|
||||
#include "test/testsupport/file_utils.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
using ReaderPairType =
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>;
|
||||
// File reader for binary files that contain a sequence of values with
|
||||
// arithmetic type `T`. The values of type `T` that are read are cast to float.
|
||||
template <typename T>
|
||||
class FloatFileReader : public FileReader {
|
||||
public:
|
||||
static_assert(std::is_arithmetic<T>::value, "");
|
||||
explicit FloatFileReader(absl::string_view filename)
|
||||
: is_(std::string(filename), std::ios::binary | std::ios::ate),
|
||||
size_(is_.tellg() / sizeof(T)) {
|
||||
RTC_CHECK(is_);
|
||||
SeekBeginning();
|
||||
}
|
||||
FloatFileReader(const FloatFileReader&) = delete;
|
||||
FloatFileReader& operator=(const FloatFileReader&) = delete;
|
||||
~FloatFileReader() = default;
|
||||
|
||||
int size() const override { return size_; }
|
||||
bool ReadChunk(rtc::ArrayView<float> dst) override {
|
||||
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
|
||||
if (std::is_same<T, float>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
|
||||
} else {
|
||||
buffer_.resize(dst.size());
|
||||
is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read);
|
||||
std::transform(buffer_.begin(), buffer_.end(), dst.begin(),
|
||||
[](const T& v) -> float { return static_cast<float>(v); });
|
||||
}
|
||||
return is_.gcount() == bytes_to_read;
|
||||
}
|
||||
bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); }
|
||||
void SeekForward(int hop) override { is_.seekg(hop * sizeof(T), is_.cur); }
|
||||
void SeekBeginning() override { is_.seekg(0, is_.beg); }
|
||||
|
||||
private:
|
||||
std::ifstream is_;
|
||||
const int size_;
|
||||
std::vector<T> buffer_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -33,7 +73,7 @@ using webrtc::test::ResourcePath;
|
||||
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed) {
|
||||
ASSERT_EQ(expected.size(), computed.size());
|
||||
for (size_t i = 0; i < expected.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
EXPECT_FLOAT_EQ(expected[i], computed[i]);
|
||||
}
|
||||
@ -43,87 +83,61 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed,
|
||||
float tolerance) {
|
||||
ASSERT_EQ(expected.size(), computed.size());
|
||||
for (size_t i = 0; i < expected.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
EXPECT_NEAR(expected[i], computed[i], tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
|
||||
CreatePcmSamplesReader(const size_t frame_length) {
|
||||
auto ptr = std::make_unique<BinaryFileReader<int16_t, float>>(
|
||||
test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"),
|
||||
frame_length);
|
||||
// The last incomplete frame is ignored.
|
||||
return {std::move(ptr), ptr->data_length() / frame_length};
|
||||
std::unique_ptr<FileReader> CreatePcmSamplesReader() {
|
||||
return std::make_unique<FloatFileReader<int16_t>>(
|
||||
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples",
|
||||
"pcm"));
|
||||
}
|
||||
|
||||
ReaderPairType CreatePitchBuffer24kHzReader() {
|
||||
constexpr size_t cols = 864;
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
|
||||
return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)};
|
||||
ChunksFileReader CreatePitchBuffer24kHzReader() {
|
||||
auto reader = std::make_unique<FloatFileReader<float>>(
|
||||
/*filename=*/test::ResourcePath(
|
||||
"audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"));
|
||||
const int num_chunks = rtc::CheckedDivExact(reader->size(), kBufSize24kHz);
|
||||
return {/*chunk_size=*/kBufSize24kHz, num_chunks, std::move(reader)};
|
||||
}
|
||||
|
||||
ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
|
||||
constexpr size_t num_lp_residual_coeffs = 864;
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
|
||||
num_lp_residual_coeffs);
|
||||
return {std::move(ptr),
|
||||
rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
|
||||
ChunksFileReader CreateLpResidualAndPitchInfoReader() {
|
||||
constexpr int kPitchInfoSize = 2; // Pitch period and strength.
|
||||
constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize;
|
||||
auto reader = std::make_unique<FloatFileReader<float>>(
|
||||
/*filename=*/test::ResourcePath(
|
||||
"audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"));
|
||||
const int num_chunks = rtc::CheckedDivExact(reader->size(), kChunkSize);
|
||||
return {kChunkSize, num_chunks, std::move(reader)};
|
||||
}
|
||||
|
||||
ReaderPairType CreateVadProbsReader() {
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
|
||||
return {std::move(ptr), ptr->data_length()};
|
||||
std::unique_ptr<FileReader> CreateGruInputReader() {
|
||||
return std::make_unique<FloatFileReader<float>>(
|
||||
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in",
|
||||
"dat"));
|
||||
}
|
||||
|
||||
std::unique_ptr<FileReader> CreateVadProbsReader() {
|
||||
return std::make_unique<FloatFileReader<float>>(
|
||||
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
|
||||
"dat"));
|
||||
}
|
||||
|
||||
PitchTestData::PitchTestData() {
|
||||
BinaryFileReader<float> test_data_reader(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
|
||||
static_cast<size_t>(1396));
|
||||
test_data_reader.ReadChunk(test_data_);
|
||||
FloatFileReader<float> reader(
|
||||
/*filename=*/ResourcePath(
|
||||
"audio_processing/agc2/rnn_vad/pitch_search_int", "dat"));
|
||||
reader.ReadChunk(pitch_buffer_24k_);
|
||||
reader.ReadChunk(square_energies_24k_);
|
||||
reader.ReadChunk(auto_correlation_12k_);
|
||||
// Reverse the order of the squared energy values.
|
||||
// Required after the WebRTC CL 191703 which switched to forward computation.
|
||||
std::reverse(square_energies_24k_.begin(), square_energies_24k_.end());
|
||||
}
|
||||
|
||||
PitchTestData::~PitchTestData() = default;
|
||||
|
||||
rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
|
||||
const {
|
||||
return {test_data_.data(), kBufSize24kHz};
|
||||
}
|
||||
|
||||
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
|
||||
PitchTestData::GetPitchBufSquareEnergiesView() const {
|
||||
return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
|
||||
}
|
||||
|
||||
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
|
||||
PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
|
||||
return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
|
||||
kNumPitchBufAutoCorrCoeffs};
|
||||
}
|
||||
|
||||
bool IsOptimizationAvailable(Optimization optimization) {
|
||||
switch (optimization) {
|
||||
case Optimization::kSse2:
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
return GetCPUInfo(kSSE2) != 0;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
case Optimization::kNeon:
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
case Optimization::kNone:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -11,23 +11,19 @@
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
|
||||
constexpr float kFloatMin = std::numeric_limits<float>::min();
|
||||
|
||||
@ -42,98 +38,51 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed,
|
||||
float tolerance);
|
||||
|
||||
// Reader for binary files consisting of an arbitrary long sequence of elements
|
||||
// having type T. It is possible to read and cast to another type D at once.
|
||||
template <typename T, typename D = T>
|
||||
class BinaryFileReader {
|
||||
// File reader interface.
|
||||
class FileReader {
|
||||
public:
|
||||
explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0)
|
||||
: is_(file_path, std::ios::binary | std::ios::ate),
|
||||
data_length_(is_.tellg() / sizeof(T)),
|
||||
chunk_size_(chunk_size) {
|
||||
RTC_CHECK(is_);
|
||||
SeekBeginning();
|
||||
buf_.resize(chunk_size_);
|
||||
}
|
||||
BinaryFileReader(const BinaryFileReader&) = delete;
|
||||
BinaryFileReader& operator=(const BinaryFileReader&) = delete;
|
||||
~BinaryFileReader() = default;
|
||||
size_t data_length() const { return data_length_; }
|
||||
bool ReadValue(D* dst) {
|
||||
if (std::is_same<T, D>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst), sizeof(T));
|
||||
} else {
|
||||
T v;
|
||||
is_.read(reinterpret_cast<char*>(&v), sizeof(T));
|
||||
*dst = static_cast<D>(v);
|
||||
}
|
||||
return is_.gcount() == sizeof(T);
|
||||
}
|
||||
// If |chunk_size| was specified in the ctor, it will check that the size of
|
||||
// |dst| equals |chunk_size|.
|
||||
bool ReadChunk(rtc::ArrayView<D> dst) {
|
||||
RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size()));
|
||||
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
|
||||
if (std::is_same<T, D>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
|
||||
} else {
|
||||
is_.read(reinterpret_cast<char*>(buf_.data()), bytes_to_read);
|
||||
std::transform(buf_.begin(), buf_.end(), dst.begin(),
|
||||
[](const T& v) -> D { return static_cast<D>(v); });
|
||||
}
|
||||
return is_.gcount() == bytes_to_read;
|
||||
}
|
||||
void SeekForward(size_t items) { is_.seekg(items * sizeof(T), is_.cur); }
|
||||
void SeekBeginning() { is_.seekg(0, is_.beg); }
|
||||
|
||||
private:
|
||||
std::ifstream is_;
|
||||
const size_t data_length_;
|
||||
const size_t chunk_size_;
|
||||
std::vector<T> buf_;
|
||||
virtual ~FileReader() = default;
|
||||
// Number of values in the file.
|
||||
virtual int size() const = 0;
|
||||
// Reads `dst.size()` float values into `dst`, advances the internal file
|
||||
// position according to the number of read bytes and returns true if the
|
||||
// values are correctly read. If the number of remaining bytes in the file is
|
||||
// not sufficient to read `dst.size()` float values, `dst` is partially
|
||||
// modified and false is returned.
|
||||
virtual bool ReadChunk(rtc::ArrayView<float> dst) = 0;
|
||||
// Reads a single float value, advances the internal file position according
|
||||
// to the number of read bytes and returns true if the value is correctly
|
||||
// read. If the number of remaining bytes in the file is not sufficient to
|
||||
// read one float, `dst` is not modified and false is returned.
|
||||
virtual bool ReadValue(float& dst) = 0;
|
||||
// Advances the internal file position by `hop` float values.
|
||||
virtual void SeekForward(int hop) = 0;
|
||||
// Resets the internal file position to BOF.
|
||||
virtual void SeekBeginning() = 0;
|
||||
};
|
||||
|
||||
// Writer for binary files.
|
||||
template <typename T>
|
||||
class BinaryFileWriter {
|
||||
public:
|
||||
explicit BinaryFileWriter(const std::string& file_path)
|
||||
: os_(file_path, std::ios::binary) {}
|
||||
BinaryFileWriter(const BinaryFileWriter&) = delete;
|
||||
BinaryFileWriter& operator=(const BinaryFileWriter&) = delete;
|
||||
~BinaryFileWriter() = default;
|
||||
static_assert(std::is_arithmetic<T>::value, "");
|
||||
void WriteChunk(rtc::ArrayView<const T> value) {
|
||||
const std::streamsize bytes_to_write = value.size() * sizeof(T);
|
||||
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
|
||||
}
|
||||
|
||||
private:
|
||||
std::ofstream os_;
|
||||
// File reader for files that contain `num_chunks` chunks with size equal to
|
||||
// `chunk_size`.
|
||||
struct ChunksFileReader {
|
||||
const int chunk_size;
|
||||
const int num_chunks;
|
||||
std::unique_ptr<FileReader> reader;
|
||||
};
|
||||
|
||||
// Factories for resource file readers.
|
||||
// The functions below return a pair where the first item is a reader unique
|
||||
// pointer and the second the number of chunks that can be read from the file.
|
||||
// Creates a reader for the PCM samples that casts from S16 to float and reads
|
||||
// chunks with length |frame_length|.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
|
||||
CreatePcmSamplesReader(const size_t frame_length);
|
||||
// Creates a reader for the pitch buffer content at 24 kHz.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
||||
CreatePitchBuffer24kHzReader();
|
||||
// Creates a reader for the the LP residual coefficients and the pitch period
|
||||
// and gain values.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
||||
CreateLpResidualAndPitchPeriodGainReader();
|
||||
// Creates a reader for the VAD probabilities.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
||||
CreateVadProbsReader();
|
||||
// Creates a reader for the PCM S16 samples file.
|
||||
std::unique_ptr<FileReader> CreatePcmSamplesReader();
|
||||
|
||||
constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
|
||||
constexpr size_t kNumPitchBufSquareEnergies = 385;
|
||||
constexpr size_t kPitchTestDataSize =
|
||||
kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
|
||||
// Creates a reader for the 24 kHz pitch buffer test data.
|
||||
ChunksFileReader CreatePitchBuffer24kHzReader();
|
||||
|
||||
// Creates a reader for the LP residual and pitch information test data.
|
||||
ChunksFileReader CreateLpResidualAndPitchInfoReader();
|
||||
|
||||
// Creates a reader for the sequence of GRU input vectors.
|
||||
std::unique_ptr<FileReader> CreateGruInputReader();
|
||||
|
||||
// Creates a reader for the VAD probabilities test data.
|
||||
std::unique_ptr<FileReader> CreateVadProbsReader();
|
||||
|
||||
// Class to retrieve a test pitch buffer content and the expected output for the
|
||||
// analysis steps.
|
||||
@ -141,20 +90,40 @@ class PitchTestData {
|
||||
public:
|
||||
PitchTestData();
|
||||
~PitchTestData();
|
||||
rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() const;
|
||||
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
|
||||
GetPitchBufSquareEnergiesView() const;
|
||||
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
|
||||
GetPitchBufAutoCorrCoeffsView() const;
|
||||
rtc::ArrayView<const float, kBufSize24kHz> PitchBuffer24kHzView() const {
|
||||
return pitch_buffer_24k_;
|
||||
}
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> SquareEnergies24kHzView()
|
||||
const {
|
||||
return square_energies_24k_;
|
||||
}
|
||||
rtc::ArrayView<const float, kNumLags12kHz> AutoCorrelation12kHzView() const {
|
||||
return auto_correlation_12k_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<float, kPitchTestDataSize> test_data_;
|
||||
std::array<float, kBufSize24kHz> pitch_buffer_24k_;
|
||||
std::array<float, kRefineNumLags24kHz> square_energies_24k_;
|
||||
std::array<float, kNumLags12kHz> auto_correlation_12k_;
|
||||
};
|
||||
|
||||
// Returns true if the given optimization is available.
|
||||
bool IsOptimizationAvailable(Optimization optimization);
|
||||
// Writer for binary files.
|
||||
class FileWriter {
|
||||
public:
|
||||
explicit FileWriter(absl::string_view file_path)
|
||||
: os_(std::string(file_path), std::ios::binary) {}
|
||||
FileWriter(const FileWriter&) = delete;
|
||||
FileWriter& operator=(const FileWriter&) = delete;
|
||||
~FileWriter() = default;
|
||||
void WriteChunk(rtc::ArrayView<const float> value) {
|
||||
const std::streamsize bytes_to_write = value.size() * sizeof(float);
|
||||
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
|
||||
}
|
||||
|
||||
private:
|
||||
std::ofstream os_;
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
|
114
webrtc/modules/audio_processing/agc2/rnn_vad/vector_math.h
Normal file
114
webrtc/modules/audio_processing/agc2/rnn_vad/vector_math.h
Normal file
@ -0,0 +1,114 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
|
||||
|
||||
// Defines WEBRTC_ARCH_X86_FAMILY, used below.
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
#include <emmintrin.h>
|
||||
#endif
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Provides optimizations for mathematical operations having vectors as
|
||||
// operand(s).
|
||||
class VectorMath {
|
||||
public:
|
||||
explicit VectorMath(AvailableCpuFeatures cpu_features)
|
||||
: cpu_features_(cpu_features) {}
|
||||
|
||||
// Computes the dot product between two equally sized vectors.
|
||||
float DotProduct(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y) const {
|
||||
RTC_DCHECK_EQ(x.size(), y.size());
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
if (cpu_features_.avx2) {
|
||||
return DotProductAvx2(x, y);
|
||||
} else if (cpu_features_.sse2) {
|
||||
__m128 accumulator = _mm_setzero_ps();
|
||||
constexpr int kBlockSizeLog2 = 2;
|
||||
constexpr int kBlockSize = 1 << kBlockSizeLog2;
|
||||
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
|
||||
<< kBlockSizeLog2;
|
||||
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
|
||||
RTC_DCHECK_LE(i + kBlockSize, x.size());
|
||||
const __m128 x_i = _mm_loadu_ps(&x[i]);
|
||||
const __m128 y_i = _mm_loadu_ps(&y[i]);
|
||||
// Multiply-add.
|
||||
const __m128 z_j = _mm_mul_ps(x_i, y_i);
|
||||
accumulator = _mm_add_ps(accumulator, z_j);
|
||||
}
|
||||
// Reduce `accumulator` by addition.
|
||||
__m128 high = _mm_movehl_ps(accumulator, accumulator);
|
||||
accumulator = _mm_add_ps(accumulator, high);
|
||||
high = _mm_shuffle_ps(accumulator, accumulator, 1);
|
||||
accumulator = _mm_add_ps(accumulator, high);
|
||||
float dot_product = _mm_cvtss_f32(accumulator);
|
||||
// Add the result for the last block if incomplete.
|
||||
for (int i = incomplete_block_index;
|
||||
i < rtc::dchecked_cast<int>(x.size()); ++i) {
|
||||
dot_product += x[i] * y[i];
|
||||
}
|
||||
return dot_product;
|
||||
}
|
||||
#elif defined(WEBRTC_HAS_NEON) && defined(WEBRTC_ARCH_ARM64)
|
||||
if (cpu_features_.neon) {
|
||||
float32x4_t accumulator = vdupq_n_f32(0.f);
|
||||
constexpr int kBlockSizeLog2 = 2;
|
||||
constexpr int kBlockSize = 1 << kBlockSizeLog2;
|
||||
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
|
||||
<< kBlockSizeLog2;
|
||||
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
|
||||
RTC_DCHECK_LE(i + kBlockSize, x.size());
|
||||
const float32x4_t x_i = vld1q_f32(&x[i]);
|
||||
const float32x4_t y_i = vld1q_f32(&y[i]);
|
||||
accumulator = vfmaq_f32(accumulator, x_i, y_i);
|
||||
}
|
||||
// Reduce `accumulator` by addition.
|
||||
const float32x2_t tmp =
|
||||
vpadd_f32(vget_low_f32(accumulator), vget_high_f32(accumulator));
|
||||
float dot_product = vget_lane_f32(vpadd_f32(tmp, vrev64_f32(tmp)), 0);
|
||||
// Add the result for the last block if incomplete.
|
||||
for (int i = incomplete_block_index;
|
||||
i < rtc::dchecked_cast<int>(x.size()); ++i) {
|
||||
dot_product += x[i] * y[i];
|
||||
}
|
||||
return dot_product;
|
||||
}
|
||||
#endif
|
||||
return std::inner_product(x.begin(), x.end(), y.begin(), 0.f);
|
||||
}
|
||||
|
||||
private:
|
||||
float DotProductAvx2(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y) const;
|
||||
|
||||
const AvailableCpuFeatures cpu_features_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
|
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
float VectorMath::DotProductAvx2(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y) const {
|
||||
RTC_DCHECK(cpu_features_.avx2);
|
||||
RTC_DCHECK_EQ(x.size(), y.size());
|
||||
__m256 accumulator = _mm256_setzero_ps();
|
||||
constexpr int kBlockSizeLog2 = 3;
|
||||
constexpr int kBlockSize = 1 << kBlockSizeLog2;
|
||||
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
|
||||
<< kBlockSizeLog2;
|
||||
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
|
||||
RTC_DCHECK_LE(i + kBlockSize, x.size());
|
||||
const __m256 x_i = _mm256_loadu_ps(&x[i]);
|
||||
const __m256 y_i = _mm256_loadu_ps(&y[i]);
|
||||
accumulator = _mm256_fmadd_ps(x_i, y_i, accumulator);
|
||||
}
|
||||
// Reduce `accumulator` by addition.
|
||||
__m128 high = _mm256_extractf128_ps(accumulator, 1);
|
||||
__m128 low = _mm256_extractf128_ps(accumulator, 0);
|
||||
low = _mm_add_ps(high, low);
|
||||
high = _mm_movehl_ps(high, low);
|
||||
low = _mm_add_ps(high, low);
|
||||
high = _mm_shuffle_ps(low, low, 1);
|
||||
low = _mm_add_ss(high, low);
|
||||
float dot_product = _mm_cvtss_f32(low);
|
||||
// Add the result for the last block if incomplete.
|
||||
for (int i = incomplete_block_index; i < rtc::dchecked_cast<int>(x.size());
|
||||
++i) {
|
||||
dot_product += x[i] * y[i];
|
||||
}
|
||||
return dot_product;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
Reference in New Issue
Block a user