Bump to WebRTC M120 release
Some API deprecation -- ExperimentalAgc and ExperimentalNs are gone. We're continuing to carry iSAC even though it's gone upstream, but maybe we'll want to drop that soon.
This commit is contained in:
198
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
Normal file
198
webrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
Normal file
@ -0,0 +1,198 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
constexpr int kNumGruGates = 3; // Update, reset, output.
|
||||
|
||||
std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
|
||||
int output_size) {
|
||||
// Transpose, cast and scale.
|
||||
// `n` is the size of the first dimension of the 3-dim tensor `weights`.
|
||||
const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
|
||||
output_size * kNumGruGates);
|
||||
const int stride_src = kNumGruGates * output_size;
|
||||
const int stride_dst = n * output_size;
|
||||
std::vector<float> tensor_dst(tensor_src.size());
|
||||
for (int g = 0; g < kNumGruGates; ++g) {
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
tensor_dst[g * stride_dst + o * n + i] =
|
||||
::rnnoise::kWeightsScale *
|
||||
static_cast<float>(
|
||||
tensor_src[i * stride_src + g * output_size + o]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return tensor_dst;
|
||||
}
|
||||
|
||||
// Computes the output for the update or the reset gate.
|
||||
// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
|
||||
// - `g`: output gate vector
|
||||
// - `W`: weights matrix
|
||||
// - `i`: input vector
|
||||
// - `R`: recurrent weights matrix
|
||||
// - `s`: state gate vector
|
||||
// - `b`: bias vector
|
||||
void ComputeUpdateResetGate(int input_size,
|
||||
int output_size,
|
||||
const VectorMath& vector_math,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> state,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<float> gate) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_EQ(state.size(), output_size);
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
|
||||
RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated.
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
float x = bias[o];
|
||||
x += vector_math.DotProduct(input,
|
||||
weights.subview(o * input_size, input_size));
|
||||
x += vector_math.DotProduct(
|
||||
state, recurrent_weights.subview(o * output_size, output_size));
|
||||
gate[o] = ::rnnoise::SigmoidApproximated(x);
|
||||
}
|
||||
}
|
||||
|
||||
// Computes the output for the state gate.
|
||||
// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
|
||||
// - `s'`: output state gate vector
|
||||
// - `s`: previous state gate vector
|
||||
// - `u`: update gate vector
|
||||
// - `W`: weights matrix
|
||||
// - `i`: input vector
|
||||
// - `R`: recurrent weights matrix
|
||||
// - `r`: reset gate vector
|
||||
// - `b`: bias vector
|
||||
// - `.*` element-wise product
|
||||
void ComputeStateGate(int input_size,
|
||||
int output_size,
|
||||
const VectorMath& vector_math,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> update,
|
||||
rtc::ArrayView<const float> reset,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<float> state) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated.
|
||||
RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated.
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
|
||||
RTC_DCHECK_EQ(state.size(), output_size);
|
||||
std::array<float, kGruLayerMaxUnits> reset_x_state;
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
reset_x_state[o] = state[o] * reset[o];
|
||||
}
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
float x = bias[o];
|
||||
x += vector_math.DotProduct(input,
|
||||
weights.subview(o * input_size, input_size));
|
||||
x += vector_math.DotProduct(
|
||||
{reset_x_state.data(), static_cast<size_t>(output_size)},
|
||||
recurrent_weights.subview(o * output_size, output_size));
|
||||
state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GatedRecurrentLayer::GatedRecurrentLayer(
|
||||
const int input_size,
|
||||
const int output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
const rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
const AvailableCpuFeatures& cpu_features,
|
||||
absl::string_view layer_name)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(PreprocessGruTensor(bias, output_size)),
|
||||
weights_(PreprocessGruTensor(weights, output_size)),
|
||||
recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
|
||||
vector_math_(cpu_features) {
|
||||
RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
|
||||
<< "Insufficient GRU layer over-allocation (" << layer_name << ").";
|
||||
RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
|
||||
<< "Mismatching output size and bias terms array size (" << layer_name
|
||||
<< ").";
|
||||
RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
|
||||
<< "Mismatching input-output size and weight coefficients array size ("
|
||||
<< layer_name << ").";
|
||||
RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
|
||||
recurrent_weights_.size())
|
||||
<< "Mismatching input-output size and recurrent weight coefficients array"
|
||||
" size ("
|
||||
<< layer_name << ").";
|
||||
Reset();
|
||||
}
|
||||
|
||||
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
|
||||
|
||||
void GatedRecurrentLayer::Reset() {
|
||||
state_.fill(0.f);
|
||||
}
|
||||
|
||||
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size_);
|
||||
|
||||
// The tensors below are organized as a sequence of flattened tensors for the
|
||||
// `update`, `reset` and `state` gates.
|
||||
rtc::ArrayView<const float> bias(bias_);
|
||||
rtc::ArrayView<const float> weights(weights_);
|
||||
rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
|
||||
// Strides to access to the flattened tensors for a specific gate.
|
||||
const int stride_weights = input_size_ * output_size_;
|
||||
const int stride_recurrent_weights = output_size_ * output_size_;
|
||||
|
||||
rtc::ArrayView<float> state(state_.data(), output_size_);
|
||||
|
||||
// Update gate.
|
||||
std::array<float, kGruLayerMaxUnits> update;
|
||||
ComputeUpdateResetGate(
|
||||
input_size_, output_size_, vector_math_, input, state,
|
||||
bias.subview(0, output_size_), weights.subview(0, stride_weights),
|
||||
recurrent_weights.subview(0, stride_recurrent_weights), update);
|
||||
// Reset gate.
|
||||
std::array<float, kGruLayerMaxUnits> reset;
|
||||
ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
|
||||
bias.subview(output_size_, output_size_),
|
||||
weights.subview(stride_weights, stride_weights),
|
||||
recurrent_weights.subview(stride_recurrent_weights,
|
||||
stride_recurrent_weights),
|
||||
reset);
|
||||
// State gate.
|
||||
ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
|
||||
reset, bias.subview(2 * output_size_, output_size_),
|
||||
weights.subview(2 * stride_weights, stride_weights),
|
||||
recurrent_weights.subview(2 * stride_recurrent_weights,
|
||||
stride_recurrent_weights),
|
||||
state);
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
Reference in New Issue
Block a user