API deprecations: ExperimentalAgc and ExperimentalNs have been removed. We continue to carry iSAC even though it has been removed upstream, but we may want to drop it soon.
199 lines
8.2 KiB
C++
199 lines
8.2 KiB
C++
/*
|
|
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
|
*
|
|
* Use of this source code is governed by a BSD-style license
|
|
* that can be found in the LICENSE file in the root of the source
|
|
* tree. An additional intellectual property rights grant can be found
|
|
* in the file PATENTS. All contributing project authors may
|
|
* be found in the AUTHORS file in the root of the source tree.
|
|
*/
|
|
|
|
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"

#include <algorithm>

#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
|
|
|
namespace webrtc {
|
|
namespace rnn_vad {
|
|
namespace {
|
|
|
|
// Number of gates in a GRU layer: update, reset and state (in this order in
// the flattened parameter tensors).
constexpr int kNumGruGates = 3;
|
|
|
|
// Converts a quantized GRU parameter tensor into a scaled float tensor.
// The source is read as a 3-dim tensor indexed as [row][gate][unit], where
// each row holds `kNumGruGates * output_size` coefficients. The destination is
// written as [gate][unit][row], so the coefficients used for output unit `unit`
// of gate `gate` end up contiguous. Each value is cast to float and multiplied
// by `::rnnoise::kWeightsScale`. The number of rows is derived from the source
// size, which must be an exact multiple of `kNumGruGates * output_size`.
std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
                                       int output_size) {
  // Number of rows, i.e., the size of the first source dimension.
  const int num_rows = rtc::CheckedDivExact(
      rtc::dchecked_cast<int>(tensor_src.size()), output_size * kNumGruGates);
  // Length of one source row (all gates, all units).
  const int row_len = kNumGruGates * output_size;
  // Length of one per-gate block in the destination.
  const int gate_len = num_rows * output_size;

  std::vector<float> tensor_dst(tensor_src.size());
  for (int row = 0; row < num_rows; ++row) {
    for (int gate = 0; gate < kNumGruGates; ++gate) {
      for (int unit = 0; unit < output_size; ++unit) {
        const int8_t quantized =
            tensor_src[row * row_len + gate * output_size + unit];
        tensor_dst[gate * gate_len + unit * num_rows + row] =
            ::rnnoise::kWeightsScale * static_cast<float>(quantized);
      }
    }
  }
  return tensor_dst;
}
|
|
|
|
// Computes the output for the update or the reset gate.
|
|
// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
|
|
// - `g`: output gate vector
|
|
// - `W`: weights matrix
|
|
// - `i`: input vector
|
|
// - `R`: recurrent weights matrix
|
|
// - `s`: state gate vector
|
|
// - `b`: bias vector
|
|
void ComputeUpdateResetGate(int input_size,
|
|
int output_size,
|
|
const VectorMath& vector_math,
|
|
rtc::ArrayView<const float> input,
|
|
rtc::ArrayView<const float> state,
|
|
rtc::ArrayView<const float> bias,
|
|
rtc::ArrayView<const float> weights,
|
|
rtc::ArrayView<const float> recurrent_weights,
|
|
rtc::ArrayView<float> gate) {
|
|
RTC_DCHECK_EQ(input.size(), input_size);
|
|
RTC_DCHECK_EQ(state.size(), output_size);
|
|
RTC_DCHECK_EQ(bias.size(), output_size);
|
|
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
|
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
|
|
RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated.
|
|
for (int o = 0; o < output_size; ++o) {
|
|
float x = bias[o];
|
|
x += vector_math.DotProduct(input,
|
|
weights.subview(o * input_size, input_size));
|
|
x += vector_math.DotProduct(
|
|
state, recurrent_weights.subview(o * output_size, output_size));
|
|
gate[o] = ::rnnoise::SigmoidApproximated(x);
|
|
}
|
|
}
|
|
|
|
// Computes the output for the state gate.
|
|
// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
|
|
// - `s'`: output state gate vector
|
|
// - `s`: previous state gate vector
|
|
// - `u`: update gate vector
|
|
// - `W`: weights matrix
|
|
// - `i`: input vector
|
|
// - `R`: recurrent weights matrix
|
|
// - `r`: reset gate vector
|
|
// - `b`: bias vector
|
|
// - `.*` element-wise product
|
|
void ComputeStateGate(int input_size,
|
|
int output_size,
|
|
const VectorMath& vector_math,
|
|
rtc::ArrayView<const float> input,
|
|
rtc::ArrayView<const float> update,
|
|
rtc::ArrayView<const float> reset,
|
|
rtc::ArrayView<const float> bias,
|
|
rtc::ArrayView<const float> weights,
|
|
rtc::ArrayView<const float> recurrent_weights,
|
|
rtc::ArrayView<float> state) {
|
|
RTC_DCHECK_EQ(input.size(), input_size);
|
|
RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated.
|
|
RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated.
|
|
RTC_DCHECK_EQ(bias.size(), output_size);
|
|
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
|
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
|
|
RTC_DCHECK_EQ(state.size(), output_size);
|
|
std::array<float, kGruLayerMaxUnits> reset_x_state;
|
|
for (int o = 0; o < output_size; ++o) {
|
|
reset_x_state[o] = state[o] * reset[o];
|
|
}
|
|
for (int o = 0; o < output_size; ++o) {
|
|
float x = bias[o];
|
|
x += vector_math.DotProduct(input,
|
|
weights.subview(o * input_size, input_size));
|
|
x += vector_math.DotProduct(
|
|
{reset_x_state.data(), static_cast<size_t>(output_size)},
|
|
recurrent_weights.subview(o * output_size, output_size));
|
|
state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Builds a GRU layer from quantized parameters. `bias`, `weights` and
// `recurrent_weights` are each converted by `PreprocessGruTensor()` into a
// scaled float tensor stored gate by gate (update, reset, state). In debug
// builds the resulting sizes are validated against `input_size` and
// `output_size` (with `layer_name` quoted in any failure message). Finally the
// hidden state is cleared via `Reset()`.
GatedRecurrentLayer::GatedRecurrentLayer(
    const int input_size,
    const int output_size,
    const rtc::ArrayView<const int8_t> bias,
    const rtc::ArrayView<const int8_t> weights,
    const rtc::ArrayView<const int8_t> recurrent_weights,
    const AvailableCpuFeatures& cpu_features,
    absl::string_view layer_name)
    : input_size_(input_size),
      output_size_(output_size),
      bias_(PreprocessGruTensor(bias, output_size)),
      weights_(PreprocessGruTensor(weights, output_size)),
      recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
      vector_math_(cpu_features) {
  // The gate and state buffers have a fixed capacity of `kGruLayerMaxUnits`.
  RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
      << "Insufficient GRU layer over-allocation (" << layer_name << ").";
  // One bias term per unit and per gate.
  RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
      << "Mismatching output size and bias terms array size (" << layer_name
      << ").";
  // One `input_size_ x output_size_` matrix per gate.
  RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
      << "Mismatching input-output size and weight coefficients array size ("
      << layer_name << ").";
  // One `output_size_ x output_size_` matrix per gate.
  RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
                recurrent_weights_.size())
      << "Mismatching input-output size and recurrent weight coefficients "
         "array size ("
      << layer_name << ").";
  Reset();
}
|
|
|
|
// Defaulted destructor, defined out of line. Members are released by their
// own destructors; nothing else is cleaned up here.
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
|
|
|
|
// Clears the hidden state: every element of `state_` is set to zero, not only
// the first `output_size_` ones.
void GatedRecurrentLayer::Reset() {
  for (float& value : state_) {
    value = 0.f;
  }
}
|
|
|
|
// Runs one GRU step on `input`. The update and reset gates are computed from
// `input` and the current state; the state gate then overwrites the first
// `output_size_` elements of `state_` with the new hidden state.
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
  RTC_DCHECK_EQ(input.size(), input_size_);

  // `bias_`, `weights_` and `recurrent_weights_` each hold one flattened
  // tensor per gate, stored back to back in the order update, reset, state.
  // The helpers below return the slice that belongs to gate index `g`.
  const rtc::ArrayView<const float> all_bias(bias_);
  const rtc::ArrayView<const float> all_weights(weights_);
  const rtc::ArrayView<const float> all_recurrent_weights(recurrent_weights_);
  const int weights_stride = input_size_ * output_size_;
  const int recurrent_weights_stride = output_size_ * output_size_;
  const auto bias_for = [&](int g) {
    return all_bias.subview(g * output_size_, output_size_);
  };
  const auto weights_for = [&](int g) {
    return all_weights.subview(g * weights_stride, weights_stride);
  };
  const auto recurrent_weights_for = [&](int g) {
    return all_recurrent_weights.subview(g * recurrent_weights_stride,
                                         recurrent_weights_stride);
  };
  constexpr int kUpdateGate = 0;
  constexpr int kResetGate = 1;
  constexpr int kStateGate = 2;

  // View over the valid part of the over-allocated state buffer.
  rtc::ArrayView<float> state(state_.data(), output_size_);

  std::array<float, kGruLayerMaxUnits> update;
  ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
                         bias_for(kUpdateGate), weights_for(kUpdateGate),
                         recurrent_weights_for(kUpdateGate), update);

  std::array<float, kGruLayerMaxUnits> reset;
  ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
                         bias_for(kResetGate), weights_for(kResetGate),
                         recurrent_weights_for(kResetGate), reset);

  // Writes the new hidden state into `state` (i.e., into `state_`).
  ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
                   reset, bias_for(kStateGate), weights_for(kStateGate),
                   recurrent_weights_for(kStateGate), state);
}
|
|
|
|
} // namespace rnn_vad
|
|
} // namespace webrtc
|