mirror of
https://github.com/esphome/esphome.git
synced 2024-12-22 21:44:55 +01:00
microWakeWord - add new ops and small improvements (#6360)
This commit is contained in:
parent
d121fa5d05
commit
9e378189c3
2 changed files with 27 additions and 42 deletions
|
@ -93,11 +93,18 @@ int MicroWakeWord::read_microphone_() {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t bytes_written = this->ring_buffer_->write((void *) this->input_buffer_, bytes_read);
|
size_t bytes_free = this->ring_buffer_->free();
|
||||||
if (bytes_written != bytes_read) {
|
|
||||||
ESP_LOGW(TAG, "Failed to write some data to ring buffer (written=%d, expected=%d)", bytes_written, bytes_read);
|
if (bytes_free < bytes_read) {
|
||||||
|
ESP_LOGW(TAG,
|
||||||
|
"Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). "
|
||||||
|
"Resetting the ring buffer. Wake word detection accuracy will be reduced.",
|
||||||
|
bytes_free, bytes_read);
|
||||||
|
|
||||||
|
this->ring_buffer_->reset();
|
||||||
}
|
}
|
||||||
return bytes_written;
|
|
||||||
|
return this->ring_buffer_->write((void *) this->input_buffer_, bytes_read);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MicroWakeWord::loop() {
|
void MicroWakeWord::loop() {
|
||||||
|
@ -206,12 +213,6 @@ bool MicroWakeWord::initialize_models() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
this->preprocessor_stride_buffer_ = audio_samples_allocator.allocate(HISTORY_SAMPLES_TO_KEEP);
|
|
||||||
if (this->preprocessor_stride_buffer_ == nullptr) {
|
|
||||||
ESP_LOGE(TAG, "Could not allocate the audio preprocessor's stride buffer.");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
this->preprocessor_model_ = tflite::GetModel(G_AUDIO_PREPROCESSOR_INT8_TFLITE);
|
this->preprocessor_model_ = tflite::GetModel(G_AUDIO_PREPROCESSOR_INT8_TFLITE);
|
||||||
if (this->preprocessor_model_->version() != TFLITE_SCHEMA_VERSION) {
|
if (this->preprocessor_model_->version() != TFLITE_SCHEMA_VERSION) {
|
||||||
ESP_LOGE(TAG, "Wake word's audio preprocessor model's schema is not supported");
|
ESP_LOGE(TAG, "Wake word's audio preprocessor model's schema is not supported");
|
||||||
|
@ -225,7 +226,7 @@ bool MicroWakeWord::initialize_models() {
|
||||||
}
|
}
|
||||||
|
|
||||||
static tflite::MicroMutableOpResolver<18> preprocessor_op_resolver;
|
static tflite::MicroMutableOpResolver<18> preprocessor_op_resolver;
|
||||||
static tflite::MicroMutableOpResolver<14> streaming_op_resolver;
|
static tflite::MicroMutableOpResolver<17> streaming_op_resolver;
|
||||||
|
|
||||||
if (!this->register_preprocessor_ops_(preprocessor_op_resolver))
|
if (!this->register_preprocessor_ops_(preprocessor_op_resolver))
|
||||||
return false;
|
return false;
|
||||||
|
@ -329,7 +330,6 @@ bool MicroWakeWord::detect_wake_word_() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform inference
|
// Perform inference
|
||||||
uint32_t streaming_size = micros();
|
|
||||||
float streaming_prob = this->perform_streaming_inference_();
|
float streaming_prob = this->perform_streaming_inference_();
|
||||||
|
|
||||||
// Add the most recent probability to the sliding window
|
// Add the most recent probability to the sliding window
|
||||||
|
@ -357,6 +357,9 @@ bool MicroWakeWord::detect_wake_word_() {
|
||||||
for (auto &prob : this->recent_streaming_probabilities_) {
|
for (auto &prob : this->recent_streaming_probabilities_) {
|
||||||
prob = 0;
|
prob = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ESP_LOGD(TAG, "Wake word sliding average probability is %.3f and most recent probability is %.3f",
|
||||||
|
sliding_window_average, streaming_prob);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -371,23 +374,6 @@ void MicroWakeWord::set_sliding_window_average_size(size_t size) {
|
||||||
bool MicroWakeWord::slice_available_() {
|
bool MicroWakeWord::slice_available_() {
|
||||||
size_t available = this->ring_buffer_->available();
|
size_t available = this->ring_buffer_->available();
|
||||||
|
|
||||||
size_t free = this->ring_buffer_->free();
|
|
||||||
|
|
||||||
if (free < NEW_SAMPLES_TO_GET * sizeof(int16_t)) {
|
|
||||||
// If the ring buffer is within one audio slice of being full, then wake word detection will have issues.
|
|
||||||
// If this is constantly occuring, then some possibilities why are
|
|
||||||
// 1) there are too many other slow components configured
|
|
||||||
// 2) the ESP32 isn't fast enough; e.g., an ESP32 is much slower than an ESP32-S3 at inferences.
|
|
||||||
// 3) the model is too large
|
|
||||||
// 4) the model uses operations that are not optimized
|
|
||||||
ESP_LOGW(TAG,
|
|
||||||
"Audio buffer is nearly full. Wake word detection may be less accurate and have slower reponse times. "
|
|
||||||
#if !defined(USE_ESP32_VARIANT_ESP32S3)
|
|
||||||
"microWakeWord is designed for the ESP32-S3. The current platform is too slow for this model."
|
|
||||||
#endif
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return available > (NEW_SAMPLES_TO_GET * sizeof(int16_t));
|
return available > (NEW_SAMPLES_TO_GET * sizeof(int16_t));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -396,13 +382,12 @@ bool MicroWakeWord::stride_audio_samples_(int16_t **audio_samples) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy 320 bytes (160 samples over 10 ms) into preprocessor_audio_buffer_ from history in
|
// Copy the last 320 bytes (160 samples over 10 ms) from the audio buffer to the start of the audio buffer
|
||||||
// preprocessor_stride_buffer_
|
memcpy((void *) (this->preprocessor_audio_buffer_), (void *) (this->preprocessor_audio_buffer_ + NEW_SAMPLES_TO_GET),
|
||||||
memcpy((void *) (this->preprocessor_audio_buffer_), (void *) (this->preprocessor_stride_buffer_),
|
|
||||||
HISTORY_SAMPLES_TO_KEEP * sizeof(int16_t));
|
HISTORY_SAMPLES_TO_KEEP * sizeof(int16_t));
|
||||||
|
|
||||||
// Copy 640 bytes (320 samples over 20 ms) from the ring buffer
|
// Copy 640 bytes (320 samples over 20 ms) from the ring buffer into the audio buffer offset 320 bytes (160 samples
|
||||||
// The first 320 bytes (160 samples over 10 ms) will be from history
|
// over 10 ms)
|
||||||
size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_ + HISTORY_SAMPLES_TO_KEEP),
|
size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_ + HISTORY_SAMPLES_TO_KEEP),
|
||||||
NEW_SAMPLES_TO_GET * sizeof(int16_t), pdMS_TO_TICKS(200));
|
NEW_SAMPLES_TO_GET * sizeof(int16_t), pdMS_TO_TICKS(200));
|
||||||
|
|
||||||
|
@ -415,11 +400,6 @@ bool MicroWakeWord::stride_audio_samples_(int16_t **audio_samples) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy the last 320 bytes (160 samples over 10 ms) from the audio buffer into history stride buffer for the next
|
|
||||||
// iteration
|
|
||||||
memcpy((void *) (this->preprocessor_stride_buffer_), (void *) (this->preprocessor_audio_buffer_ + NEW_SAMPLES_TO_GET),
|
|
||||||
HISTORY_SAMPLES_TO_KEEP * sizeof(int16_t));
|
|
||||||
|
|
||||||
*audio_samples = this->preprocessor_audio_buffer_;
|
*audio_samples = this->preprocessor_audio_buffer_;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -480,7 +460,7 @@ bool MicroWakeWord::register_preprocessor_ops_(tflite::MicroMutableOpResolver<18
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<14> &op_resolver) {
|
bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<17> &op_resolver) {
|
||||||
if (op_resolver.AddCallOnce() != kTfLiteOk)
|
if (op_resolver.AddCallOnce() != kTfLiteOk)
|
||||||
return false;
|
return false;
|
||||||
if (op_resolver.AddVarHandle() != kTfLiteOk)
|
if (op_resolver.AddVarHandle() != kTfLiteOk)
|
||||||
|
@ -509,6 +489,12 @@ bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<14> &
|
||||||
return false;
|
return false;
|
||||||
if (op_resolver.AddQuantize() != kTfLiteOk)
|
if (op_resolver.AddQuantize() != kTfLiteOk)
|
||||||
return false;
|
return false;
|
||||||
|
if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk)
|
||||||
|
return false;
|
||||||
|
if (op_resolver.AddAveragePool2D() != kTfLiteOk)
|
||||||
|
return false;
|
||||||
|
if (op_resolver.AddMaxPool2D() != kTfLiteOk)
|
||||||
|
return false;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -128,7 +128,6 @@ class MicroWakeWord : public Component {
|
||||||
|
|
||||||
// Stores audio fed into feature generator preprocessor
|
// Stores audio fed into feature generator preprocessor
|
||||||
int16_t *preprocessor_audio_buffer_;
|
int16_t *preprocessor_audio_buffer_;
|
||||||
int16_t *preprocessor_stride_buffer_;
|
|
||||||
|
|
||||||
bool detected_{false};
|
bool detected_{false};
|
||||||
|
|
||||||
|
@ -181,7 +180,7 @@ class MicroWakeWord : public Component {
|
||||||
bool register_preprocessor_ops_(tflite::MicroMutableOpResolver<18> &op_resolver);
|
bool register_preprocessor_ops_(tflite::MicroMutableOpResolver<18> &op_resolver);
|
||||||
|
|
||||||
/// @brief Returns true if successfully registered the streaming model's TensorFlow operations
|
/// @brief Returns true if successfully registered the streaming model's TensorFlow operations
|
||||||
bool register_streaming_ops_(tflite::MicroMutableOpResolver<14> &op_resolver);
|
bool register_streaming_ops_(tflite::MicroMutableOpResolver<17> &op_resolver);
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> {
|
template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> {
|
||||||
|
|
Loading…
Reference in a new issue