mirror of
https://github.com/esphome/esphome.git
synced 2024-12-22 13:34:54 +01:00
Add on_client_connected and disconnected to voice assistant (#5629)
This commit is contained in:
parent
69ec647f7e
commit
193bac94f4
9 changed files with 98 additions and 51 deletions
|
@ -18,6 +18,8 @@ from esphome.const import (
|
|||
CONF_TRIGGER_ID,
|
||||
CONF_EVENT,
|
||||
CONF_TAG,
|
||||
CONF_ON_CLIENT_CONNECTED,
|
||||
CONF_ON_CLIENT_DISCONNECTED,
|
||||
)
|
||||
from esphome.core import coroutine_with_priority
|
||||
|
||||
|
@ -45,8 +47,6 @@ SERVICE_ARG_NATIVE_TYPES = {
|
|||
"string[]": cg.std_vector.template(cg.std_string),
|
||||
}
|
||||
CONF_ENCRYPTION = "encryption"
|
||||
CONF_ON_CLIENT_CONNECTED = "on_client_connected"
|
||||
CONF_ON_CLIENT_DISCONNECTED = "on_client_disconnected"
|
||||
|
||||
|
||||
def validate_encryption_key(value):
|
||||
|
|
|
@ -60,6 +60,11 @@ APIConnection::~APIConnection() {
|
|||
bluetooth_proxy::global_bluetooth_proxy->unsubscribe_api_connection(this);
|
||||
}
|
||||
#endif
|
||||
#ifdef USE_VOICE_ASSISTANT
|
||||
if (voice_assistant::global_voice_assistant->get_api_connection() == this) {
|
||||
voice_assistant::global_voice_assistant->client_subscription(this, false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void APIConnection::loop() {
|
||||
|
@ -950,14 +955,17 @@ BluetoothConnectionsFreeResponse APIConnection::subscribe_bluetooth_connections_
|
|||
#endif
|
||||
|
||||
#ifdef USE_VOICE_ASSISTANT
|
||||
bool APIConnection::request_voice_assistant(const VoiceAssistantRequest &msg) {
|
||||
if (!this->voice_assistant_subscription_)
|
||||
return false;
|
||||
|
||||
return this->send_voice_assistant_request(msg);
|
||||
void APIConnection::subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) {
|
||||
if (voice_assistant::global_voice_assistant != nullptr) {
|
||||
voice_assistant::global_voice_assistant->client_subscription(this, msg.subscribe);
|
||||
}
|
||||
}
|
||||
void APIConnection::on_voice_assistant_response(const VoiceAssistantResponse &msg) {
|
||||
if (voice_assistant::global_voice_assistant != nullptr) {
|
||||
if (voice_assistant::global_voice_assistant->get_api_connection() != this) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (msg.error) {
|
||||
voice_assistant::global_voice_assistant->failed_to_start();
|
||||
return;
|
||||
|
@ -970,6 +978,10 @@ void APIConnection::on_voice_assistant_response(const VoiceAssistantResponse &ms
|
|||
};
|
||||
void APIConnection::on_voice_assistant_event_response(const VoiceAssistantEventResponse &msg) {
|
||||
if (voice_assistant::global_voice_assistant != nullptr) {
|
||||
if (voice_assistant::global_voice_assistant->get_api_connection() != this) {
|
||||
return;
|
||||
}
|
||||
|
||||
voice_assistant::global_voice_assistant->on_event(msg);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -126,10 +126,7 @@ class APIConnection : public APIServerConnection {
|
|||
#endif
|
||||
|
||||
#ifdef USE_VOICE_ASSISTANT
|
||||
void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) override {
|
||||
this->voice_assistant_subscription_ = msg.subscribe;
|
||||
}
|
||||
bool request_voice_assistant(const VoiceAssistantRequest &msg);
|
||||
void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) override;
|
||||
void on_voice_assistant_response(const VoiceAssistantResponse &msg) override;
|
||||
void on_voice_assistant_event_response(const VoiceAssistantEventResponse &msg) override;
|
||||
#endif
|
||||
|
@ -188,6 +185,8 @@ class APIConnection : public APIServerConnection {
|
|||
}
|
||||
bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) override;
|
||||
|
||||
std::string get_client_combined_info() const { return this->client_combined_info_; }
|
||||
|
||||
protected:
|
||||
friend APIServer;
|
||||
|
||||
|
@ -220,9 +219,6 @@ class APIConnection : public APIServerConnection {
|
|||
uint32_t last_traffic_;
|
||||
bool sent_ping_{false};
|
||||
bool service_call_subscription_{false};
|
||||
#ifdef USE_VOICE_ASSISTANT
|
||||
bool voice_assistant_subscription_{false};
|
||||
#endif
|
||||
bool next_close_ = false;
|
||||
APIServer *parent_;
|
||||
InitialStateIterator initial_state_iterator_;
|
||||
|
|
|
@ -332,30 +332,6 @@ void APIServer::on_shutdown() {
|
|||
delay(10);
|
||||
}
|
||||
|
||||
#ifdef USE_VOICE_ASSISTANT
|
||||
bool APIServer::start_voice_assistant(const std::string &conversation_id, uint32_t flags,
|
||||
const api::VoiceAssistantAudioSettings &audio_settings) {
|
||||
VoiceAssistantRequest msg;
|
||||
msg.start = true;
|
||||
msg.conversation_id = conversation_id;
|
||||
msg.flags = flags;
|
||||
msg.audio_settings = audio_settings;
|
||||
for (auto &c : this->clients_) {
|
||||
if (c->request_voice_assistant(msg))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
void APIServer::stop_voice_assistant() {
|
||||
VoiceAssistantRequest msg;
|
||||
msg.start = false;
|
||||
for (auto &c : this->clients_) {
|
||||
if (c->request_voice_assistant(msg))
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ALARM_CONTROL_PANEL
|
||||
void APIServer::on_alarm_control_panel_update(alarm_control_panel::AlarmControlPanel *obj) {
|
||||
if (obj->is_internal())
|
||||
|
|
|
@ -84,12 +84,6 @@ class APIServer : public Component, public Controller {
|
|||
void request_time();
|
||||
#endif
|
||||
|
||||
#ifdef USE_VOICE_ASSISTANT
|
||||
bool start_voice_assistant(const std::string &conversation_id, uint32_t flags,
|
||||
const api::VoiceAssistantAudioSettings &audio_settings);
|
||||
void stop_voice_assistant();
|
||||
#endif
|
||||
|
||||
#ifdef USE_ALARM_CONTROL_PANEL
|
||||
void on_alarm_control_panel_update(alarm_control_panel::AlarmControlPanel *obj) override;
|
||||
#endif
|
||||
|
|
|
@ -6,6 +6,8 @@ from esphome.const import (
|
|||
CONF_MICROPHONE,
|
||||
CONF_SPEAKER,
|
||||
CONF_MEDIA_PLAYER,
|
||||
CONF_ON_CLIENT_CONNECTED,
|
||||
CONF_ON_CLIENT_DISCONNECTED,
|
||||
)
|
||||
from esphome import automation
|
||||
from esphome.automation import register_action, register_condition
|
||||
|
@ -80,6 +82,12 @@ CONFIG_SCHEMA = cv.All(
|
|||
cv.Optional(CONF_ON_TTS_END): automation.validate_automation(single=True),
|
||||
cv.Optional(CONF_ON_END): automation.validate_automation(single=True),
|
||||
cv.Optional(CONF_ON_ERROR): automation.validate_automation(single=True),
|
||||
cv.Optional(CONF_ON_CLIENT_CONNECTED): automation.validate_automation(
|
||||
single=True
|
||||
),
|
||||
cv.Optional(CONF_ON_CLIENT_DISCONNECTED): automation.validate_automation(
|
||||
single=True
|
||||
),
|
||||
}
|
||||
).extend(cv.COMPONENT_SCHEMA),
|
||||
)
|
||||
|
@ -155,6 +163,20 @@ async def to_code(config):
|
|||
config[CONF_ON_ERROR],
|
||||
)
|
||||
|
||||
if CONF_ON_CLIENT_CONNECTED in config:
|
||||
await automation.build_automation(
|
||||
var.get_client_connected_trigger(),
|
||||
[],
|
||||
config[CONF_ON_CLIENT_CONNECTED],
|
||||
)
|
||||
|
||||
if CONF_ON_CLIENT_DISCONNECTED in config:
|
||||
await automation.build_automation(
|
||||
var.get_client_disconnected_trigger(),
|
||||
[],
|
||||
config[CONF_ON_CLIENT_DISCONNECTED],
|
||||
)
|
||||
|
||||
cg.add_define("USE_VOICE_ASSISTANT")
|
||||
|
||||
|
||||
|
|
|
@ -127,8 +127,8 @@ int VoiceAssistant::read_microphone_() {
|
|||
}
|
||||
|
||||
void VoiceAssistant::loop() {
|
||||
if (this->state_ != State::IDLE && this->state_ != State::STOP_MICROPHONE &&
|
||||
this->state_ != State::STOPPING_MICROPHONE && !api::global_api_server->is_connected()) {
|
||||
if (this->api_client_ == nullptr && this->state_ != State::IDLE && this->state_ != State::STOP_MICROPHONE &&
|
||||
this->state_ != State::STOPPING_MICROPHONE) {
|
||||
if (this->mic_->is_running() || this->state_ == State::STARTING_MICROPHONE) {
|
||||
this->set_state_(State::STOP_MICROPHONE, State::IDLE);
|
||||
} else {
|
||||
|
@ -213,7 +213,14 @@ void VoiceAssistant::loop() {
|
|||
audio_settings.noise_suppression_level = this->noise_suppression_level_;
|
||||
audio_settings.auto_gain = this->auto_gain_;
|
||||
audio_settings.volume_multiplier = this->volume_multiplier_;
|
||||
if (!api::global_api_server->start_voice_assistant(this->conversation_id_, flags, audio_settings)) {
|
||||
|
||||
api::VoiceAssistantRequest msg;
|
||||
msg.start = true;
|
||||
msg.conversation_id = this->conversation_id_;
|
||||
msg.flags = flags;
|
||||
msg.audio_settings = audio_settings;
|
||||
|
||||
if (this->api_client_ == nullptr || !this->api_client_->send_voice_assistant_request(msg)) {
|
||||
ESP_LOGW(TAG, "Could not request start.");
|
||||
this->error_trigger_->trigger("not-connected", "Could not request start.");
|
||||
this->continuous_ = false;
|
||||
|
@ -326,6 +333,28 @@ void VoiceAssistant::loop() {
|
|||
}
|
||||
}
|
||||
|
||||
void VoiceAssistant::client_subscription(api::APIConnection *client, bool subscribe) {
|
||||
if (!subscribe) {
|
||||
if (this->api_client_ == nullptr || client != this->api_client_) {
|
||||
ESP_LOGE(TAG, "Client attempting to unsubscribe that is not the current API Client");
|
||||
return;
|
||||
}
|
||||
this->api_client_ = nullptr;
|
||||
this->client_disconnected_trigger_->trigger();
|
||||
return;
|
||||
}
|
||||
|
||||
if (this->api_client_ != nullptr) {
|
||||
ESP_LOGE(TAG, "Multiple API Clients attempting to connect to Voice Assistant");
|
||||
ESP_LOGE(TAG, "Current client: %s", this->api_client_->get_client_combined_info().c_str());
|
||||
ESP_LOGE(TAG, "New client: %s", client->get_client_combined_info().c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
this->api_client_ = client;
|
||||
this->client_connected_trigger_->trigger();
|
||||
}
|
||||
|
||||
static const LogString *voice_assistant_state_to_string(State state) {
|
||||
switch (state) {
|
||||
case State::IDLE:
|
||||
|
@ -408,7 +437,7 @@ void VoiceAssistant::start_streaming(struct sockaddr_storage *addr, uint16_t por
|
|||
}
|
||||
|
||||
void VoiceAssistant::request_start(bool continuous, bool silence_detection) {
|
||||
if (!api::global_api_server->is_connected()) {
|
||||
if (this->api_client_ == nullptr) {
|
||||
ESP_LOGE(TAG, "No API client connected");
|
||||
this->set_state_(State::IDLE, State::IDLE);
|
||||
this->continuous_ = false;
|
||||
|
@ -459,9 +488,14 @@ void VoiceAssistant::request_stop() {
|
|||
}
|
||||
|
||||
void VoiceAssistant::signal_stop_() {
|
||||
ESP_LOGD(TAG, "Signaling stop...");
|
||||
api::global_api_server->stop_voice_assistant();
|
||||
memset(&this->dest_addr_, 0, sizeof(this->dest_addr_));
|
||||
if (this->api_client_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
ESP_LOGD(TAG, "Signaling stop...");
|
||||
api::VoiceAssistantRequest msg;
|
||||
msg.start = false;
|
||||
this->api_client_->send_voice_assistant_request(msg);
|
||||
}
|
||||
|
||||
void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
|
||||
|
|
|
@ -8,8 +8,8 @@
|
|||
#include "esphome/core/component.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
|
||||
#include "esphome/components/api/api_connection.h"
|
||||
#include "esphome/components/api/api_pb2.h"
|
||||
#include "esphome/components/api/api_server.h"
|
||||
#include "esphome/components/microphone/microphone.h"
|
||||
#ifdef USE_SPEAKER
|
||||
#include "esphome/components/speaker/speaker.h"
|
||||
|
@ -109,6 +109,12 @@ class VoiceAssistant : public Component {
|
|||
Trigger<> *get_end_trigger() const { return this->end_trigger_; }
|
||||
Trigger<std::string, std::string> *get_error_trigger() const { return this->error_trigger_; }
|
||||
|
||||
Trigger<> *get_client_connected_trigger() const { return this->client_connected_trigger_; }
|
||||
Trigger<> *get_client_disconnected_trigger() const { return this->client_disconnected_trigger_; }
|
||||
|
||||
void client_subscription(api::APIConnection *client, bool subscribe);
|
||||
api::APIConnection *get_api_connection() const { return this->api_client_; }
|
||||
|
||||
protected:
|
||||
int read_microphone_();
|
||||
void set_state_(State state);
|
||||
|
@ -127,6 +133,11 @@ class VoiceAssistant : public Component {
|
|||
Trigger<> *end_trigger_ = new Trigger<>();
|
||||
Trigger<std::string, std::string> *error_trigger_ = new Trigger<std::string, std::string>();
|
||||
|
||||
Trigger<> *client_connected_trigger_ = new Trigger<>();
|
||||
Trigger<> *client_disconnected_trigger_ = new Trigger<>();
|
||||
|
||||
api::APIConnection *api_client_{nullptr};
|
||||
|
||||
microphone::Microphone *mic_{nullptr};
|
||||
#ifdef USE_SPEAKER
|
||||
speaker::Speaker *speaker_{nullptr};
|
||||
|
|
|
@ -485,6 +485,8 @@ CONF_ON_BLE_MANUFACTURER_DATA_ADVERTISE = "on_ble_manufacturer_data_advertise"
|
|||
CONF_ON_BLE_SERVICE_DATA_ADVERTISE = "on_ble_service_data_advertise"
|
||||
CONF_ON_BOOT = "on_boot"
|
||||
CONF_ON_CLICK = "on_click"
|
||||
CONF_ON_CLIENT_CONNECTED = "on_client_connected"
|
||||
CONF_ON_CLIENT_DISCONNECTED = "on_client_disconnected"
|
||||
CONF_ON_CONNECT = "on_connect"
|
||||
CONF_ON_CONTROL = "on_control"
|
||||
CONF_ON_DISCONNECT = "on_disconnect"
|
||||
|
|
Loading…
Reference in a new issue