Bump to WebRTC M120 release

Some API deprecations -- ExperimentalAgc and ExperimentalNs are gone.
We're continuing to carry iSAC even though it's gone upstream, but maybe
we'll want to drop that soon.
Arun Raghavan
2023-12-12 10:42:58 -05:00
parent 9a202fb8c2
commit c6abf6cd3f
479 changed files with 20900 additions and 11996 deletions
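The commit message notes that the ExperimentalAgc and ExperimentalNs config structs are gone upstream. As a rough, hedged sketch of the kind of migration a downstream caller of the C++ API may need -- the helper name and the AudioProcessing::Config field names below are assumptions based on the M120 headers, not something this commit shows:

#include "modules/audio_processing/include/audio_processing.h"

void SketchConfigMigration(webrtc::AudioProcessing& apm) {
  // Previously (removed upstream):
  //   webrtc::Config config;
  //   config.Set<webrtc::ExperimentalAgc>(new webrtc::ExperimentalAgc(true));
  //   config.Set<webrtc::ExperimentalNs>(new webrtc::ExperimentalNs(true));
  //
  // Approximate replacement via AudioProcessing::Config:
  webrtc::AudioProcessing::Config config;
  config.gain_controller1.enabled = true;
  config.gain_controller1.analog_gain_controller.enabled = true;  // roughly ExperimentalAgc
  config.transient_suppression.enabled = true;                    // roughly ExperimentalNs
  apm.ApplyConfig(config);
}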


@@ -14,10 +14,10 @@ rtc_source_set("transient_suppressor_api") {
rtc_library("transient_suppressor_impl") {
visibility = [
"..:optionally_built_submodule_creators",
":click_annotate",
":transient_suppression_test",
":transient_suppression_unittests",
":click_annotate",
"..:optionally_built_submodule_creators",
]
sources = [
"common.h",
@@ -37,6 +37,7 @@ rtc_library("transient_suppressor_impl") {
]
deps = [
":transient_suppressor_api",
":voice_probability_delay_unit",
"../../../common_audio:common_audio",
"../../../common_audio:common_audio_c",
"../../../common_audio:fir_filter",
@@ -48,43 +49,56 @@ rtc_library("transient_suppressor_impl") {
]
}
if (rtc_include_tests) {
rtc_executable("click_annotate") {
testonly = true
sources = [
"click_annotate.cc",
"file_utils.cc",
"file_utils.h",
]
deps = [
":transient_suppressor_impl",
"..:audio_processing",
"../../../rtc_base/system:file_wrapper",
"../../../system_wrappers",
]
}
rtc_library("voice_probability_delay_unit") {
sources = [
"voice_probability_delay_unit.cc",
"voice_probability_delay_unit.h",
]
deps = [ "../../../rtc_base:checks" ]
}
rtc_executable("transient_suppression_test") {
testonly = true
sources = [
"file_utils.cc",
"file_utils.h",
"transient_suppression_test.cc",
]
deps = [
":transient_suppressor_impl",
"..:audio_processing",
"../../../common_audio",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base/system:file_wrapper",
"../../../system_wrappers",
"../../../test:fileutils",
"../../../test:test_support",
"../agc:level_estimation",
"//testing/gtest",
"//third_party/abseil-cpp/absl/flags:flag",
"//third_party/abseil-cpp/absl/flags:parse",
]
if (rtc_include_tests) {
if (!build_with_chromium) {
rtc_executable("click_annotate") {
testonly = true
sources = [
"click_annotate.cc",
"file_utils.cc",
"file_utils.h",
]
deps = [
":transient_suppressor_impl",
"..:audio_processing",
"../../../rtc_base/system:file_wrapper",
"../../../system_wrappers",
]
}
rtc_executable("transient_suppression_test") {
testonly = true
sources = [
"file_utils.cc",
"file_utils.h",
"transient_suppression_test.cc",
"voice_probability_delay_unit_unittest.cc",
]
deps = [
":transient_suppressor_api",
":transient_suppressor_impl",
":voice_probability_delay_unit",
"..:audio_processing",
"../../../common_audio",
"../../../rtc_base/system:file_wrapper",
"../../../system_wrappers",
"../../../test:fileutils",
"../../../test:test_support",
"../agc:level_estimation",
"//testing/gtest",
"//third_party/abseil-cpp/absl/flags:flag",
"//third_party/abseil-cpp/absl/flags:parse",
"//third_party/abseil-cpp/absl/types:optional",
]
}
}
rtc_library("transient_suppression_unittests") {
@@ -97,16 +111,23 @@ if (rtc_include_tests) {
"moving_moments_unittest.cc",
"transient_detector_unittest.cc",
"transient_suppressor_unittest.cc",
"voice_probability_delay_unit_unittest.cc",
"wpd_node_unittest.cc",
"wpd_tree_unittest.cc",
]
deps = [
":transient_suppressor_api",
":transient_suppressor_impl",
":voice_probability_delay_unit",
"../../../rtc_base:stringutils",
"../../../rtc_base/system:file_wrapper",
"../../../test:fileutils",
"../../../test:test_support",
"//testing/gtest",
]
absl_deps = [
"//third_party/abseil-cpp/absl/strings",
"//third_party/abseil-cpp/absl/types:optional",
]
}
}


@@ -1,107 +0,0 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include <cfloat>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <vector>
#include "modules/audio_processing/transient/file_utils.h"
#include "modules/audio_processing/transient/transient_detector.h"
#include "rtc_base/system/file_wrapper.h"
using webrtc::FileWrapper;
using webrtc::TransientDetector;
// Application to generate a RTP timing file.
// Opens the PCM file and divides the signal in frames.
// Creates a send times array, one for each step.
// Each block that contains a transient, has an infinite send time.
// The resultant array is written to a DAT file
// Returns -1 on error or |lost_packets| otherwise.
int main(int argc, char* argv[]) {
if (argc != 5) {
printf("\n%s - Application to generate a RTP timing file.\n\n", argv[0]);
printf("%s PCMfile DATfile chunkSize sampleRate\n\n", argv[0]);
printf("Opens the PCMfile with sampleRate in Hertz.\n");
printf("Creates a send times array, one for each chunkSize ");
printf("milliseconds step.\n");
printf("Each block that contains a transient, has an infinite send time. ");
printf("The resultant array is written to a DATfile.\n\n");
return 0;
}
FileWrapper pcm_file = FileWrapper::OpenReadOnly(argv[1]);
if (!pcm_file.is_open()) {
printf("\nThe %s could not be opened.\n\n", argv[1]);
return -1;
}
FileWrapper dat_file = FileWrapper::OpenWriteOnly(argv[2]);
if (!dat_file.is_open()) {
printf("\nThe %s could not be opened.\n\n", argv[2]);
return -1;
}
int chunk_size_ms = atoi(argv[3]);
if (chunk_size_ms <= 0) {
printf("\nThe chunkSize must be a positive integer\n\n");
return -1;
}
int sample_rate_hz = atoi(argv[4]);
if (sample_rate_hz <= 0) {
printf("\nThe sampleRate must be a positive integer\n\n");
return -1;
}
TransientDetector detector(sample_rate_hz);
int lost_packets = 0;
size_t audio_buffer_length = chunk_size_ms * sample_rate_hz / 1000;
std::unique_ptr<float[]> audio_buffer(new float[audio_buffer_length]);
std::vector<float> send_times;
// Read first buffer from the PCM test file.
size_t file_samples_read = ReadInt16FromFileToFloatBuffer(
&pcm_file, audio_buffer_length, audio_buffer.get());
for (int time = 0; file_samples_read > 0; time += chunk_size_ms) {
// Pad the rest of the buffer with zeros.
for (size_t i = file_samples_read; i < audio_buffer_length; ++i) {
audio_buffer[i] = 0.0;
}
float value =
detector.Detect(audio_buffer.get(), audio_buffer_length, NULL, 0);
if (value < 0.5f) {
value = time;
} else {
value = FLT_MAX;
++lost_packets;
}
send_times.push_back(value);
// Read next buffer from the PCM test file.
file_samples_read = ReadInt16FromFileToFloatBuffer(
&pcm_file, audio_buffer_length, audio_buffer.get());
}
size_t floats_written =
WriteFloatBufferToFile(&dat_file, send_times.size(), &send_times[0]);
if (floats_written == 0) {
printf("\nThe send times could not be written to DAT file\n\n");
return -1;
}
pcm_file.Close();
dat_file.Close();
return lost_packets;
}


@@ -18,7 +18,7 @@
namespace webrtc {
// Returns the proper length of the output buffer that you should use for the
// given |in_length| and decimation |odd_sequence|.
// given `in_length` and decimation `odd_sequence`.
// Return -1 on error.
inline size_t GetOutLengthToDyadicDecimate(size_t in_length,
bool odd_sequence) {
@@ -34,10 +34,10 @@ inline size_t GetOutLengthToDyadicDecimate(size_t in_length,
// Performs a dyadic decimation: removes every odd/even member of a sequence
// halving its overall length.
// Arguments:
// in: array of |in_length|.
// in: array of `in_length`.
// odd_sequence: If false, the odd members will be removed (1, 3, 5, ...);
// if true, the even members will be removed (0, 2, 4, ...).
// out: array of |out_length|. |out_length| must be large enough to
// out: array of `out_length`. `out_length` must be large enough to
// hold the decimated output. The necessary length can be provided by
// GetOutLengthToDyadicDecimate().
// Must be previously allocated.
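A small worked example of the output-length rule documented above. This is only a sketch: the header path is assumed and the function body is not part of this hunk.

#include <cstddef>

#include "modules/audio_processing/transient/dyadic_decimator.h"  // assumed path

void SketchDecimationLengths() {
  // For a 5-sample sequence {x0, x1, x2, x3, x4}:
  // odd_sequence == false removes the odd members, keeping {x0, x2, x4};
  // odd_sequence == true removes the even members, keeping {x1, x3}.
  size_t keep_even =
      webrtc::GetOutLengthToDyadicDecimate(5, /*odd_sequence=*/false);  // 3
  size_t keep_odd =
      webrtc::GetOutLengthToDyadicDecimate(5, /*odd_sequence=*/true);  // 2
  (void)keep_even;
  (void)keep_odd;
}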


@ -12,7 +12,6 @@
#define MODULES_AUDIO_PROCESSING_TRANSIENT_FILE_UTILS_H_
#include <string.h>
#include <cstdint>
#include "rtc_base/system/file_wrapper.h"
@@ -51,63 +50,63 @@ int ConvertFloatToByteArray(float value, uint8_t out_bytes[4]);
// Returns 0 if correct, -1 on error.
int ConvertDoubleToByteArray(double value, uint8_t out_bytes[8]);
// Reads |length| 16-bit integers from |file| to |buffer|.
// |file| must be previously opened.
// Reads `length` 16-bit integers from `file` to `buffer`.
// `file` must be previously opened.
// Returns the number of 16-bit integers read or -1 on error.
size_t ReadInt16BufferFromFile(FileWrapper* file,
size_t length,
int16_t* buffer);
// Reads |length| 16-bit integers from |file| and stores those values
// (converting them) in |buffer|.
// |file| must be previously opened.
// Reads `length` 16-bit integers from `file` and stores those values
// (converting them) in `buffer`.
// `file` must be previously opened.
// Returns the number of 16-bit integers read or -1 on error.
size_t ReadInt16FromFileToFloatBuffer(FileWrapper* file,
size_t length,
float* buffer);
// Reads |length| 16-bit integers from |file| and stores those values
// (converting them) in |buffer|.
// |file| must be previously opened.
// Reads `length` 16-bit integers from `file` and stores those values
// (converting them) in `buffer`.
// `file` must be previously opened.
// Returns the number of 16-bit integers read or -1 on error.
size_t ReadInt16FromFileToDoubleBuffer(FileWrapper* file,
size_t length,
double* buffer);
// Reads |length| floats in binary representation (4 bytes) from |file| to
// |buffer|.
// |file| must be previously opened.
// Reads `length` floats in binary representation (4 bytes) from `file` to
// `buffer`.
// `file` must be previously opened.
// Returns the number of floats read or -1 on error.
size_t ReadFloatBufferFromFile(FileWrapper* file, size_t length, float* buffer);
// Reads |length| doubles in binary representation (8 bytes) from |file| to
// |buffer|.
// |file| must be previously opened.
// Reads `length` doubles in binary representation (8 bytes) from `file` to
// `buffer`.
// `file` must be previously opened.
// Returns the number of doubles read or -1 on error.
size_t ReadDoubleBufferFromFile(FileWrapper* file,
size_t length,
double* buffer);
// Writes |length| 16-bit integers from |buffer| in binary representation (2
// bytes) to |file|. It flushes |file|, so after this call there are no
// Writes `length` 16-bit integers from `buffer` in binary representation (2
// bytes) to `file`. It flushes `file`, so after this call there are no
// writings pending.
// |file| must be previously opened.
// `file` must be previously opened.
// Returns the number of doubles written or -1 on error.
size_t WriteInt16BufferToFile(FileWrapper* file,
size_t length,
const int16_t* buffer);
// Writes |length| floats from |buffer| in binary representation (4 bytes) to
// |file|. It flushes |file|, so after this call there are no writtings pending.
// |file| must be previously opened.
// Writes `length` floats from `buffer` in binary representation (4 bytes) to
// `file`. It flushes `file`, so after this call there are no writtings pending.
// `file` must be previously opened.
// Returns the number of doubles written or -1 on error.
size_t WriteFloatBufferToFile(FileWrapper* file,
size_t length,
const float* buffer);
// Writes |length| doubles from |buffer| in binary representation (8 bytes) to
// |file|. It flushes |file|, so after this call there are no writings pending.
// |file| must be previously opened.
// Writes `length` doubles from `buffer` in binary representation (8 bytes) to
// `file`. It flushes `file`, so after this call there are no writings pending.
// `file` must be previously opened.
// Returns the number of doubles written or -1 on error.
size_t WriteDoubleBufferToFile(FileWrapper* file,
size_t length,

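The I/O helpers above are exercised by the deleted click_annotate.cc earlier in this commit; a minimal, self-contained sketch of the same pattern (file names and buffer size are arbitrary, error handling omitted):

#include <array>

#include "modules/audio_processing/transient/file_utils.h"
#include "rtc_base/system/file_wrapper.h"

void SketchFileUtils() {
  webrtc::FileWrapper pcm_file = webrtc::FileWrapper::OpenReadOnly("input.pcm");
  std::array<float, 160> samples;  // e.g. 10 ms at 16 kHz
  // Reads up to 160 int16 samples and converts them to float.
  size_t read = webrtc::ReadInt16FromFileToFloatBuffer(
      &pcm_file, samples.size(), samples.data());
  webrtc::FileWrapper dat_file = webrtc::FileWrapper::OpenWriteOnly("output.dat");
  // Writes `read` floats in their 4-byte binary representation and flushes.
  webrtc::WriteFloatBufferToFile(&dat_file, read, samples.data());
  pcm_file.Close();
  dat_file.Close();
}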

@@ -26,13 +26,13 @@ namespace webrtc {
// the last values of the moments. When needed.
class MovingMoments {
public:
// Creates a Moving Moments object, that uses the last |length| values
// Creates a Moving Moments object, that uses the last `length` values
// (including the new value introduced in every new calculation).
explicit MovingMoments(size_t length);
~MovingMoments();
// Calculates the new values using |in|. Results will be in the out buffers.
// |first| and |second| must be allocated with at least |in_length|.
// Calculates the new values using `in`. Results will be in the out buffers.
// `first` and `second` must be allocated with at least `in_length`.
void CalculateMoments(const float* in,
size_t in_length,
float* first,
@@ -40,7 +40,7 @@ class MovingMoments {
private:
size_t length_;
// A queue holding the |length_| latest input values.
// A queue holding the `length_` latest input values.
std::queue<float> queue_;
// Sum of the values of the queue.
float sum_;
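A minimal usage sketch of the contract described above. The header path is assumed, and the trailing `second` output parameter is inferred from the comment since the declaration is truncated in this hunk.

#include "modules/audio_processing/transient/moving_moments.h"  // assumed path

void SketchMovingMoments() {
  // Moments computed over a sliding window of the last 3 values.
  webrtc::MovingMoments moments(/*length=*/3);
  const float in[5] = {1.f, 2.f, 3.f, 4.f, 5.f};
  // Both output buffers must hold at least `in_length` elements.
  float first[5];   // first moment per input sample
  float second[5];  // second moment per input sample
  moments.CalculateMoments(in, /*in_length=*/5, first, second);
}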


@@ -43,8 +43,8 @@ TransientDetector::TransientDetector(int sample_rate_hz)
sample_rate_hz == ts::kSampleRate48kHz);
int samples_per_transient = sample_rate_hz * kTransientLengthMs / 1000;
// Adjustment to avoid data loss while downsampling, making
// |samples_per_chunk_| and |samples_per_transient| always divisible by
// |kLeaves|.
// `samples_per_chunk_` and `samples_per_transient` always divisible by
// `kLeaves`.
samples_per_chunk_ -= samples_per_chunk_ % kLeaves;
samples_per_transient -= samples_per_transient % kLeaves;
@@ -137,7 +137,7 @@ float TransientDetector::Detect(const float* data,
// In the current implementation we return the max of the current result and
// the previous results, so the high results have a width equals to
// |transient_length|.
// `transient_length`.
return *std::max_element(previous_results_.begin(), previous_results_.end());
}


@@ -37,8 +37,8 @@ class TransientDetector {
~TransientDetector();
// Calculates the log-likelihood of the existence of a transient in |data|.
// |data_length| has to be equal to |samples_per_chunk_|.
// Calculates the log-likelihood of the existence of a transient in `data`.
// `data_length` has to be equal to `samples_per_chunk_`.
// Returns a value between 0 and 1, as a non linear representation of this
// likelihood.
// Returns a negative value on error.
@@ -71,7 +71,7 @@ class TransientDetector {
float last_second_moment_[kLeaves];
// We keep track of the previous results from the previous chunks, so it can
// be used to effectively give results according to the |transient_length|.
// be used to effectively give results according to the `transient_length`.
std::deque<float> previous_results_;
// Number of chunks that are going to return only zeros at the beginning of


@@ -11,9 +11,7 @@
#ifndef MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_H_
#define MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_H_
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include <cstddef>
namespace webrtc {
@@ -21,38 +19,55 @@ namespace webrtc {
// restoration algorithm that attenuates unexpected spikes in the spectrum.
class TransientSuppressor {
public:
// Type of VAD used by the caller to compute the `voice_probability` argument
// `Suppress()`.
enum class VadMode {
// By default, `TransientSuppressor` assumes that `voice_probability` is
// computed by `AgcManagerDirect`.
kDefault = 0,
// Use this mode when `TransientSuppressor` must assume that
// `voice_probability` is computed by the RNN VAD.
kRnnVad,
// Use this mode to let `TransientSuppressor::Suppressor()` ignore
// `voice_probability` and behave as if voice information is unavailable
// (regardless of the passed value).
kNoVad,
};
virtual ~TransientSuppressor() {}
virtual int Initialize(int sample_rate_hz,
int detector_rate_hz,
int num_channels) = 0;
virtual void Initialize(int sample_rate_hz,
int detector_rate_hz,
int num_channels) = 0;
// Processes a |data| chunk, and returns it with keystrokes suppressed from
// Processes a `data` chunk, and returns it with keystrokes suppressed from
// it. The float format is assumed to be int16 ranged. If there are more than
// one channel, the chunks are concatenated one after the other in |data|.
// |data_length| must be equal to |data_length_|.
// |num_channels| must be equal to |num_channels_|.
// A sub-band, ideally the higher, can be used as |detection_data|. If it is
// NULL, |data| is used for the detection too. The |detection_data| is always
// one channel, the chunks are concatenated one after the other in `data`.
// `data_length` must be equal to `data_length_`.
// `num_channels` must be equal to `num_channels_`.
// A sub-band, ideally the higher, can be used as `detection_data`. If it is
// NULL, `data` is used for the detection too. The `detection_data` is always
// assumed mono.
// If a reference signal (e.g. keyboard microphone) is available, it can be
// passed in as |reference_data|. It is assumed mono and must have the same
// length as |data|. NULL is accepted if unavailable.
// passed in as `reference_data`. It is assumed mono and must have the same
// length as `data`. NULL is accepted if unavailable.
// This suppressor performs better if voice information is available.
// |voice_probability| is the probability of voice being present in this chunk
// of audio. If voice information is not available, |voice_probability| must
// `voice_probability` is the probability of voice being present in this chunk
// of audio. If voice information is not available, `voice_probability` must
// always be set to 1.
// |key_pressed| determines if a key was pressed on this audio chunk.
// Returns 0 on success and -1 otherwise.
virtual int Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) = 0;
// `key_pressed` determines if a key was pressed on this audio chunk.
// Returns a delayed version of `voice_probability` according to the
// algorithmic delay introduced by this method. In this way, the modified
// `data` and the returned voice probability will be temporally aligned.
virtual float Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) = 0;
};
} // namespace webrtc
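The net effect of the interface change above: Initialize() now returns void, and Suppress() returns the delayed voice probability instead of an error code. A hedged usage sketch against the abstract interface (the helper name and buffer sizes are illustrative; in practice this module is driven internally by AudioProcessing):

#include <array>

#include "modules/audio_processing/transient/transient_suppressor.h"

float SketchSuppress(webrtc::TransientSuppressor& suppressor) {
  suppressor.Initialize(/*sample_rate_hz=*/16000, /*detector_rate_hz=*/16000,
                        /*num_channels=*/1);
  std::array<float, 160> chunk{};  // 10 ms of mono, int16-ranged float samples
  // With detection_data == nullptr, `chunk` itself is used for detection, so
  // the detection length matches the chunk length here.
  float delayed_probability = suppressor.Suppress(
      chunk.data(), chunk.size(), /*num_channels=*/1,
      /*detection_data=*/nullptr, /*detection_length=*/chunk.size(),
      /*reference_data=*/nullptr, /*reference_length=*/0,
      /*voice_probability=*/1.0f, /*key_pressed=*/false);
  // The return value is the input probability delayed so that it stays aligned
  // with the (possibly modified) samples in `chunk`.
  return delayed_probability;
}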


@@ -18,6 +18,7 @@
#include <deque>
#include <limits>
#include <set>
#include <string>
#include "common_audio/include/audio_util.h"
#include "common_audio/signal_processing/include/signal_processing_library.h"
@@ -32,7 +33,6 @@
namespace webrtc {
static const float kMeanIIRCoefficient = 0.5f;
static const float kVoiceThreshold = 0.02f;
// TODO(aluebs): Check if these values work also for 48kHz.
static const size_t kMinVoiceBin = 3;
@@ -44,10 +44,27 @@ float ComplexMagnitude(float a, float b) {
return std::abs(a) + std::abs(b);
}
std::string GetVadModeLabel(TransientSuppressor::VadMode vad_mode) {
switch (vad_mode) {
case TransientSuppressor::VadMode::kDefault:
return "default";
case TransientSuppressor::VadMode::kRnnVad:
return "RNN VAD";
case TransientSuppressor::VadMode::kNoVad:
return "no VAD";
}
}
} // namespace
TransientSuppressorImpl::TransientSuppressorImpl()
: data_length_(0),
TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode,
int sample_rate_hz,
int detector_rate_hz,
int num_channels)
: vad_mode_(vad_mode),
voice_probability_delay_unit_(/*delay_num_samples=*/0, sample_rate_hz),
analyzed_audio_is_silent_(false),
data_length_(0),
detection_length_(0),
analysis_length_(0),
buffer_delay_(0),
@@ -62,13 +79,26 @@ TransientSuppressorImpl::TransientSuppressorImpl()
use_hard_restoration_(false),
chunks_since_voice_change_(0),
seed_(182),
using_reference_(false) {}
using_reference_(false) {
RTC_LOG(LS_INFO) << "VAD mode: " << GetVadModeLabel(vad_mode_);
Initialize(sample_rate_hz, detector_rate_hz, num_channels);
}
TransientSuppressorImpl::~TransientSuppressorImpl() {}
int TransientSuppressorImpl::Initialize(int sample_rate_hz,
int detection_rate_hz,
int num_channels) {
void TransientSuppressorImpl::Initialize(int sample_rate_hz,
int detection_rate_hz,
int num_channels) {
RTC_DCHECK(sample_rate_hz == ts::kSampleRate8kHz ||
sample_rate_hz == ts::kSampleRate16kHz ||
sample_rate_hz == ts::kSampleRate32kHz ||
sample_rate_hz == ts::kSampleRate48kHz);
RTC_DCHECK(detection_rate_hz == ts::kSampleRate8kHz ||
detection_rate_hz == ts::kSampleRate16kHz ||
detection_rate_hz == ts::kSampleRate32kHz ||
detection_rate_hz == ts::kSampleRate48kHz);
RTC_DCHECK_GT(num_channels, 0);
switch (sample_rate_hz) {
case ts::kSampleRate8kHz:
analysis_length_ = 128u;
@@ -87,26 +117,18 @@ int TransientSuppressorImpl::Initialize(int sample_rate_hz,
window_ = kBlocks480w1024;
break;
default:
return -1;
}
if (detection_rate_hz != ts::kSampleRate8kHz &&
detection_rate_hz != ts::kSampleRate16kHz &&
detection_rate_hz != ts::kSampleRate32kHz &&
detection_rate_hz != ts::kSampleRate48kHz) {
return -1;
}
if (num_channels <= 0) {
return -1;
RTC_DCHECK_NOTREACHED();
return;
}
detector_.reset(new TransientDetector(detection_rate_hz));
data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000;
if (data_length_ > analysis_length_) {
RTC_NOTREACHED();
return -1;
}
RTC_DCHECK_LE(data_length_, analysis_length_);
buffer_delay_ = analysis_length_ - data_length_;
voice_probability_delay_unit_.Initialize(/*delay_num_samples=*/buffer_delay_,
sample_rate_hz);
complex_analysis_length_ = analysis_length_ / 2 + 1;
RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin);
num_channels_ = num_channels;
@@ -155,28 +177,28 @@ int TransientSuppressorImpl::Initialize(int sample_rate_hz,
chunks_since_voice_change_ = 0;
seed_ = 182;
using_reference_ = false;
return 0;
}
int TransientSuppressorImpl::Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) {
float TransientSuppressorImpl::Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) {
if (!data || data_length != data_length_ || num_channels != num_channels_ ||
detection_length != detection_length_ || voice_probability < 0 ||
voice_probability > 1) {
return -1;
// The audio is not modified, so the voice probability is returned as is
// (delay not applied).
return voice_probability;
}
UpdateKeypress(key_pressed);
UpdateBuffers(data);
int result = 0;
if (detection_enabled_) {
UpdateRestoration(voice_probability);
@@ -189,12 +211,14 @@ int TransientSuppressorImpl::Suppress(float* data,
float detector_result = detector_->Detect(detection_data, detection_length,
reference_data, reference_length);
if (detector_result < 0) {
return -1;
// The audio is not modified, so the voice probability is returned as is
// (delay not applied).
return voice_probability;
}
using_reference_ = detector_->using_reference();
// |detector_smoothed_| follows the |detector_result| when this last one is
// `detector_smoothed_` follows the `detector_result` when this last one is
// increasing, but has an exponential decaying tail to be able to suppress
// the ringing of keyclicks.
float smooth_factor = using_reference_ ? 0.6 : 0.1;
@@ -219,11 +243,13 @@ int TransientSuppressorImpl::Suppress(float* data,
: &in_buffer_[i * analysis_length_],
data_length_ * sizeof(*data));
}
return result;
// The audio has been modified, return the delayed voice probability.
return voice_probability_delay_unit_.Delay(voice_probability);
}
// This should only be called when detection is enabled. UpdateBuffers() must
// have been called. At return, |out_buffer_| will be filled with the
// have been called. At return, `out_buffer_` will be filled with the
// processed output.
void TransientSuppressorImpl::Suppress(float* in_ptr,
float* spectral_mean,
@@ -304,16 +330,34 @@ void TransientSuppressorImpl::UpdateKeypress(bool key_pressed) {
}
void TransientSuppressorImpl::UpdateRestoration(float voice_probability) {
const int kHardRestorationOffsetDelay = 3;
const int kHardRestorationOnsetDelay = 80;
bool not_voiced = voice_probability < kVoiceThreshold;
bool not_voiced;
switch (vad_mode_) {
case TransientSuppressor::VadMode::kDefault: {
constexpr float kVoiceThreshold = 0.02f;
not_voiced = voice_probability < kVoiceThreshold;
break;
}
case TransientSuppressor::VadMode::kRnnVad: {
constexpr float kVoiceThreshold = 0.7f;
not_voiced = voice_probability < kVoiceThreshold;
break;
}
case TransientSuppressor::VadMode::kNoVad:
// Always assume that voice is detected.
not_voiced = false;
break;
}
if (not_voiced == use_hard_restoration_) {
chunks_since_voice_change_ = 0;
} else {
++chunks_since_voice_change_;
// Number of 10 ms frames to wait to transition to and from hard
// restoration.
constexpr int kHardRestorationOffsetDelay = 3;
constexpr int kHardRestorationOnsetDelay = 80;
if ((use_hard_restoration_ &&
chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
(!use_hard_restoration_ &&
@@ -325,7 +369,7 @@ void TransientSuppressorImpl::UpdateRestoration(float voice_probability) {
}
// Shift buffers to make way for new data. Must be called after
// |detection_enabled_| is updated by UpdateKeypress().
// `detection_enabled_` is updated by UpdateKeypress().
void TransientSuppressorImpl::UpdateBuffers(float* data) {
// TODO(aluebs): Change to ring buffer.
memmove(in_buffer_.get(), &in_buffer_[data_length_],
@@ -350,9 +394,9 @@ void TransientSuppressorImpl::UpdateBuffers(float* data) {
}
// Restores the unvoiced signal if a click is present.
// Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
// the spectral mean. The attenuation depends on |detector_smoothed_|.
// If a restoration takes place, the |magnitudes_| are updated to the new value.
// Attenuates by a certain factor every peak in the `fft_buffer_` that exceeds
// the spectral mean. The attenuation depends on `detector_smoothed_`.
// If a restoration takes place, the `magnitudes_` are updated to the new value.
void TransientSuppressorImpl::HardRestoration(float* spectral_mean) {
const float detector_result =
1.f - std::pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
@@ -376,10 +420,10 @@ void TransientSuppressorImpl::HardRestoration(float* spectral_mean) {
}
// Restores the voiced signal if a click is present.
// Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
// Attenuates by a certain factor every peak in the `fft_buffer_` that exceeds
// the spectral mean and that is lower than some function of the current block
// frequency mean. The attenuation depends on |detector_smoothed_|.
// If a restoration takes place, the |magnitudes_| are updated to the new value.
// frequency mean. The attenuation depends on `detector_smoothed_`.
// If a restoration takes place, the `magnitudes_` are updated to the new value.
void TransientSuppressorImpl::SoftRestoration(float* spectral_mean) {
// Get the spectral magnitude mean of the current block.
float block_frequency_mean = 0;


@@ -17,6 +17,7 @@
#include <memory>
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/voice_probability_delay_unit.h"
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
@@ -27,42 +28,28 @@ class TransientDetector;
// restoration algorithm that attenuates unexpected spikes in the spectrum.
class TransientSuppressorImpl : public TransientSuppressor {
public:
TransientSuppressorImpl();
TransientSuppressorImpl(VadMode vad_mode,
int sample_rate_hz,
int detector_rate_hz,
int num_channels);
~TransientSuppressorImpl() override;
int Initialize(int sample_rate_hz,
int detector_rate_hz,
int num_channels) override;
void Initialize(int sample_rate_hz,
int detector_rate_hz,
int num_channels) override;
// Processes a |data| chunk, and returns it with keystrokes suppressed from
// it. The float format is assumed to be int16 ranged. If there are more than
// one channel, the chunks are concatenated one after the other in |data|.
// |data_length| must be equal to |data_length_|.
// |num_channels| must be equal to |num_channels_|.
// A sub-band, ideally the higher, can be used as |detection_data|. If it is
// NULL, |data| is used for the detection too. The |detection_data| is always
// assumed mono.
// If a reference signal (e.g. keyboard microphone) is available, it can be
// passed in as |reference_data|. It is assumed mono and must have the same
// length as |data|. NULL is accepted if unavailable.
// This suppressor performs better if voice information is available.
// |voice_probability| is the probability of voice being present in this chunk
// of audio. If voice information is not available, |voice_probability| must
// always be set to 1.
// |key_pressed| determines if a key was pressed on this audio chunk.
// Returns 0 on success and -1 otherwise.
int Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) override;
float Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) override;
private:
FRIEND_TEST_ALL_PREFIXES(TransientSuppressorImplTest,
FRIEND_TEST_ALL_PREFIXES(TransientSuppressorVadModeParametrization,
TypingDetectionLogicWorksAsExpectedForMono);
void Suppress(float* in_ptr, float* spectral_mean, float* out_ptr);
@@ -74,8 +61,13 @@ class TransientSuppressorImpl : public TransientSuppressor {
void HardRestoration(float* spectral_mean);
void SoftRestoration(float* spectral_mean);
const VadMode vad_mode_;
VoiceProbabilityDelayUnit voice_probability_delay_unit_;
std::unique_ptr<TransientDetector> detector_;
bool analyzed_audio_is_silent_;
size_t data_length_;
size_t detection_length_;
size_t analysis_length_;


@@ -0,0 +1,56 @@
/*
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/transient/voice_probability_delay_unit.h"
#include <array>
#include "rtc_base/checks.h"
namespace webrtc {
VoiceProbabilityDelayUnit::VoiceProbabilityDelayUnit(int delay_num_samples,
int sample_rate_hz) {
Initialize(delay_num_samples, sample_rate_hz);
}
void VoiceProbabilityDelayUnit::Initialize(int delay_num_samples,
int sample_rate_hz) {
RTC_DCHECK_GE(delay_num_samples, 0);
RTC_DCHECK_LE(delay_num_samples, sample_rate_hz / 50)
<< "The implementation does not support delays greater than 20 ms.";
int frame_size = rtc::CheckedDivExact(sample_rate_hz, 100); // 10 ms.
if (delay_num_samples <= frame_size) {
weights_[0] = 0.0f;
weights_[1] = static_cast<float>(delay_num_samples) / frame_size;
weights_[2] =
static_cast<float>(frame_size - delay_num_samples) / frame_size;
} else {
delay_num_samples -= frame_size;
weights_[0] = static_cast<float>(delay_num_samples) / frame_size;
weights_[1] =
static_cast<float>(frame_size - delay_num_samples) / frame_size;
weights_[2] = 0.0f;
}
// Resets the delay unit.
last_probabilities_.fill(0.0f);
}
float VoiceProbabilityDelayUnit::Delay(float voice_probability) {
float weighted_probability = weights_[0] * last_probabilities_[0] +
weights_[1] * last_probabilities_[1] +
weights_[2] * voice_probability;
last_probabilities_[0] = last_probabilities_[1];
last_probabilities_[1] = voice_probability;
return weighted_probability;
}
} // namespace webrtc
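To make the weighting above concrete, a short sketch; the numbers follow directly from the constructor and Delay() code above, with one Delay() call per 10 ms frame:

#include "modules/audio_processing/transient/voice_probability_delay_unit.h"

void SketchDelayUnit() {
  // 5 ms delay at 48 kHz: frame_size = 480 and delay_num_samples = 240 <= 480,
  // so weights_ = {0.0f, 0.5f, 0.5f}: the output mixes the previous and the
  // current probability.
  webrtc::VoiceProbabilityDelayUnit delay_unit(/*delay_num_samples=*/240,
                                               /*sample_rate_hz=*/48000);
  delay_unit.Delay(1.0f);  // 0.5 * 0.0 (initial state) + 0.5 * 1.0 = 0.5
  delay_unit.Delay(0.0f);  // 0.5 * 1.0 + 0.5 * 0.0 = 0.5
  delay_unit.Delay(0.0f);  // 0.5 * 0.0 + 0.5 * 0.0 = 0.0

  // 15 ms delay at 48 kHz: 720 samples exceeds one 480-sample frame, so the
  // weights shift to {0.5f, 0.5f, 0.0f} and only previously observed values
  // contribute to the output.
  delay_unit.Initialize(/*delay_num_samples=*/720, /*sample_rate_hz=*/48000);
  delay_unit.Delay(1.0f);  // 0.5 * 0.0 + 0.5 * 0.0 = 0.0
}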


@@ -0,0 +1,43 @@
/*
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_TRANSIENT_VOICE_PROBABILITY_DELAY_UNIT_H_
#define MODULES_AUDIO_PROCESSING_TRANSIENT_VOICE_PROBABILITY_DELAY_UNIT_H_
#include <array>
namespace webrtc {
// Iteratively produces a sequence of delayed voice probability values given a
// fixed delay between 0 and 20 ms and given a sequence of voice probability
// values observed every 10 ms. Supports fractional delays, that are delays
// which are not a multiple integer of 10 ms. Applies interpolation with
// fractional delays; otherwise, returns a previously observed value according
// to the given fixed delay.
class VoiceProbabilityDelayUnit {
public:
// Ctor. `delay_num_samples` is the delay in number of samples and it must be
// non-negative and less than 20 ms.
VoiceProbabilityDelayUnit(int delay_num_samples, int sample_rate_hz);
// Handles delay and sample rate changes and resets the delay unit.
void Initialize(int delay_num_samples, int sample_rate_hz);
// Observes `voice_probability` and returns a delayed voice probability.
float Delay(float voice_probability);
private:
std::array<float, 3> weights_;
std::array<float, 2> last_probabilities_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_TRANSIENT_VOICE_PROBABILITY_DELAY_UNIT_H_


@@ -11,6 +11,7 @@
#ifndef MODULES_AUDIO_PROCESSING_TRANSIENT_WPD_NODE_H_
#define MODULES_AUDIO_PROCESSING_TRANSIENT_WPD_NODE_H_
#include <cstddef>
#include <memory>
namespace webrtc {
@@ -25,7 +26,7 @@ class WPDNode {
WPDNode(size_t length, const float* coefficients, size_t coefficients_length);
~WPDNode();
// Updates the node data. |parent_data| / 2 must be equals to |length_|.
// Updates the node data. `parent_data` / 2 must be equals to `length_`.
// Returns 0 if correct, and -1 otherwise.
int Update(const float* parent_data, size_t parent_data_length);


@@ -65,7 +65,7 @@ class WPDTree {
// If level or index are out of bounds the function will return NULL.
WPDNode* NodeAt(int level, int index);
// Updates all the nodes of the tree with the new data. |data_length| must be
// Updates all the nodes of the tree with the new data. `data_length` must be
// teh same that was used for the creation of the tree.
// Returns 0 if correct, and -1 otherwise.
int Update(const float* data, size_t data_length);