Voice Assistant improvements (#5827)

This commit is contained in:
Jesse Hills 2023-11-27 13:45:26 +13:00
parent 1aa49c8956
commit 8e4b9c3c1e
No known key found for this signature in database
GPG key ID: BEAAE804EFD8E83A
5 changed files with 78 additions and 39 deletions

View file

@ -220,6 +220,8 @@ size_t I2SAudioSpeaker::play(const uint8_t *data, size_t length) {
return index;
}
bool I2SAudioSpeaker::has_buffered_data() const { return uxQueueMessagesWaiting(this->buffer_queue_) > 0; }
} // namespace i2s_audio
} // namespace esphome

View file

@ -56,6 +56,8 @@ class I2SAudioSpeaker : public Component, public speaker::Speaker, public I2SAud
size_t play(const uint8_t *data, size_t length) override;
bool has_buffered_data() const override;
protected:
void start_();
// void stop_();

View file

@ -18,6 +18,8 @@ class Speaker {
virtual void start() = 0;
virtual void stop() = 0;
virtual bool has_buffered_data() const = 0;
bool is_running() const { return this->state_ == STATE_RUNNING; }
protected:

View file

@ -273,28 +273,27 @@ void VoiceAssistant::loop() {
bool playing = false;
#ifdef USE_SPEAKER
if (this->speaker_ != nullptr) {
ssize_t received_len = 0;
if (this->speaker_buffer_index_ + RECEIVE_SIZE < SPEAKER_BUFFER_SIZE) {
auto len = this->socket_->read(this->speaker_buffer_ + this->speaker_buffer_index_, RECEIVE_SIZE);
if (len > 0) {
this->speaker_buffer_index_ += len;
this->speaker_buffer_size_ += len;
received_len = this->socket_->read(this->speaker_buffer_ + this->speaker_buffer_index_, RECEIVE_SIZE);
if (received_len > 0) {
this->speaker_buffer_index_ += received_len;
this->speaker_buffer_size_ += received_len;
this->speaker_bytes_received_ += received_len;
}
} else {
ESP_LOGW(TAG, "Receive buffer full");
}
if (this->speaker_buffer_size_ > 0) {
size_t written = this->speaker_->play(this->speaker_buffer_, this->speaker_buffer_size_);
if (written > 0) {
memmove(this->speaker_buffer_, this->speaker_buffer_ + written, this->speaker_buffer_size_ - written);
this->speaker_buffer_size_ -= written;
this->speaker_buffer_index_ -= written;
this->set_timeout("speaker-timeout", 2000, [this]() { this->speaker_->stop(); });
} else {
ESP_LOGW(TAG, "Speaker buffer full");
}
ESP_LOGD(TAG, "Receive buffer full");
}
// Build a small buffer of audio before sending to the speaker
if (this->speaker_bytes_received_ > RECEIVE_SIZE * 4)
this->write_speaker_();
if (this->wait_for_stream_end_) {
this->cancel_timeout("playing");
if (this->stream_ended_ && received_len < 0) {
ESP_LOGD(TAG, "End of audio stream received");
this->cancel_timeout("speaker-timeout");
this->set_state_(State::RESPONSE_FINISHED, State::RESPONSE_FINISHED);
}
break; // We dont want to timeout here as the STREAM_END event will take care of that.
}
playing = this->speaker_->is_running();
@ -316,14 +315,26 @@ void VoiceAssistant::loop() {
case State::RESPONSE_FINISHED: {
#ifdef USE_SPEAKER
if (this->speaker_ != nullptr) {
if (this->speaker_buffer_size_ > 0) {
this->write_speaker_();
break;
}
if (this->speaker_->has_buffered_data() || this->speaker_->is_running()) {
break;
}
ESP_LOGD(TAG, "Speaker has finished outputting all audio");
this->speaker_->stop();
this->cancel_timeout("speaker-timeout");
this->cancel_timeout("playing");
this->speaker_buffer_size_ = 0;
this->speaker_buffer_index_ = 0;
this->speaker_bytes_received_ = 0;
memset(this->speaker_buffer_, 0, SPEAKER_BUFFER_SIZE);
}
this->wait_for_stream_end_ = false;
this->stream_ended_ = false;
this->tts_stream_end_trigger_->trigger();
}
#endif
this->set_state_(State::IDLE, State::IDLE);
break;
@ -333,6 +344,20 @@ void VoiceAssistant::loop() {
}
}
void VoiceAssistant::write_speaker_() {
if (this->speaker_buffer_size_ > 0) {
size_t written = this->speaker_->play(this->speaker_buffer_, this->speaker_buffer_size_);
if (written > 0) {
memmove(this->speaker_buffer_, this->speaker_buffer_ + written, this->speaker_buffer_size_ - written);
this->speaker_buffer_size_ -= written;
this->speaker_buffer_index_ -= written;
this->set_timeout("speaker-timeout", 5000, [this]() { this->speaker_->stop(); });
} else {
ESP_LOGD(TAG, "Speaker buffer full, trying again next loop");
}
}
}
void VoiceAssistant::client_subscription(api::APIConnection *client, bool subscribe) {
if (!subscribe) {
if (this->api_client_ == nullptr || client != this->api_client_) {
@ -503,21 +528,20 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
switch (msg.event_type) {
case api::enums::VOICE_ASSISTANT_RUN_START:
ESP_LOGD(TAG, "Assist Pipeline running");
this->start_trigger_->trigger();
this->defer([this]() { this->start_trigger_->trigger(); });
break;
case api::enums::VOICE_ASSISTANT_WAKE_WORD_START:
break;
case api::enums::VOICE_ASSISTANT_WAKE_WORD_END: {
ESP_LOGD(TAG, "Wake word detected");
this->wake_word_detected_trigger_->trigger();
this->defer([this]() { this->wake_word_detected_trigger_->trigger(); });
break;
}
case api::enums::VOICE_ASSISTANT_STT_START:
ESP_LOGD(TAG, "STT started");
this->listening_trigger_->trigger();
this->defer([this]() { this->listening_trigger_->trigger(); });
break;
case api::enums::VOICE_ASSISTANT_STT_END: {
this->set_state_(State::STOP_MICROPHONE, State::AWAITING_RESPONSE);
std::string text;
for (auto arg : msg.data) {
if (arg.name == "text") {
@ -529,12 +553,12 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
return;
}
ESP_LOGD(TAG, "Speech recognised as: \"%s\"", text.c_str());
this->stt_end_trigger_->trigger(text);
this->defer([this, text]() { this->stt_end_trigger_->trigger(text); });
break;
}
case api::enums::VOICE_ASSISTANT_INTENT_START:
ESP_LOGD(TAG, "Intent started");
this->intent_start_trigger_->trigger();
this->defer([this]() { this->intent_start_trigger_->trigger(); });
break;
case api::enums::VOICE_ASSISTANT_INTENT_END: {
for (auto arg : msg.data) {
@ -542,7 +566,7 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
this->conversation_id_ = std::move(arg.value);
}
}
this->intent_end_trigger_->trigger();
this->defer([this]() { this->intent_end_trigger_->trigger(); });
break;
}
case api::enums::VOICE_ASSISTANT_TTS_START: {
@ -557,10 +581,12 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
return;
}
ESP_LOGD(TAG, "Response: \"%s\"", text.c_str());
this->defer([this, text]() {
this->tts_start_trigger_->trigger(text);
#ifdef USE_SPEAKER
this->speaker_->start();
#endif
});
break;
}
case api::enums::VOICE_ASSISTANT_TTS_END: {
@ -575,14 +601,16 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
return;
}
ESP_LOGD(TAG, "Response URL: \"%s\"", url.c_str());
this->defer([this, url]() {
#ifdef USE_MEDIA_PLAYER
if (this->media_player_ != nullptr) {
this->media_player_->make_call().set_media_url(url).perform();
}
#endif
this->tts_end_trigger_->trigger(url);
});
State new_state = this->local_output_ ? State::STREAMING_RESPONSE : State::IDLE;
this->set_state_(new_state, new_state);
this->tts_end_trigger_->trigger(url);
break;
}
case api::enums::VOICE_ASSISTANT_RUN_END: {
@ -599,7 +627,7 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
this->set_state_(State::IDLE, State::IDLE);
}
}
this->end_trigger_->trigger();
this->defer([this]() { this->end_trigger_->trigger(); });
break;
}
case api::enums::VOICE_ASSISTANT_ERROR: {
@ -617,8 +645,10 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
return;
} else if (code == "wake-provider-missing" || code == "wake-engine-missing") {
// Wake word is not set up or not ready on Home Assistant so stop and do not retry until user starts again.
this->defer([this, code, message]() {
this->request_stop();
this->error_trigger_->trigger(code, message);
});
return;
}
ESP_LOGE(TAG, "Error: %s - %s", code.c_str(), message.c_str());
@ -626,32 +656,32 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
this->signal_stop_();
this->set_state_(State::STOP_MICROPHONE, State::IDLE);
}
this->error_trigger_->trigger(code, message);
this->defer([this, code, message]() { this->error_trigger_->trigger(code, message); });
break;
}
case api::enums::VOICE_ASSISTANT_TTS_STREAM_START: {
#ifdef USE_SPEAKER
this->wait_for_stream_end_ = true;
ESP_LOGD(TAG, "TTS stream start");
this->tts_stream_start_trigger_->trigger();
this->defer([this] { this->tts_stream_start_trigger_->trigger(); });
#endif
break;
}
case api::enums::VOICE_ASSISTANT_TTS_STREAM_END: {
this->set_state_(State::RESPONSE_FINISHED, State::IDLE);
#ifdef USE_SPEAKER
this->stream_ended_ = true;
ESP_LOGD(TAG, "TTS stream end");
this->tts_stream_end_trigger_->trigger();
#endif
break;
}
case api::enums::VOICE_ASSISTANT_STT_VAD_START:
ESP_LOGD(TAG, "Starting STT by VAD");
this->stt_vad_start_trigger_->trigger();
this->defer([this]() { this->stt_vad_start_trigger_->trigger(); });
break;
case api::enums::VOICE_ASSISTANT_STT_VAD_END:
ESP_LOGD(TAG, "STT by VAD end");
this->stt_vad_end_trigger_->trigger();
this->set_state_(State::STOP_MICROPHONE, State::AWAITING_RESPONSE);
this->defer([this]() { this->stt_vad_end_trigger_->trigger(); });
break;
default:
ESP_LOGD(TAG, "Unhandled event type: %d", msg.event_type);

View file

@ -156,11 +156,14 @@ class VoiceAssistant : public Component {
microphone::Microphone *mic_{nullptr};
#ifdef USE_SPEAKER
void write_speaker_();
speaker::Speaker *speaker_{nullptr};
uint8_t *speaker_buffer_;
size_t speaker_buffer_index_{0};
size_t speaker_buffer_size_{0};
size_t speaker_bytes_received_{0};
bool wait_for_stream_end_{false};
bool stream_ended_{false};
#endif
#ifdef USE_MEDIA_PLAYER
media_player::MediaPlayer *media_player_{nullptr};