Add event triggers to voice_assistant (#4699)

* Add event triggers to voice_assistant

* Add triggers to test
This commit is contained in:
Jesse Hills 2023-04-17 14:57:28 +12:00 committed by GitHub
parent 8a60919e1f
commit 3a587ea0d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 93 additions and 6 deletions

View file

@ -11,6 +11,14 @@ DEPENDENCIES = ["api", "microphone"]
CODEOWNERS = ["@jesserockz"]
CONF_ON_START = "on_start"
CONF_ON_STT_END = "on_stt_end"
CONF_ON_TTS_START = "on_tts_start"
CONF_ON_TTS_END = "on_tts_end"
CONF_ON_END = "on_end"
CONF_ON_ERROR = "on_error"
voice_assistant_ns = cg.esphome_ns.namespace("voice_assistant")
VoiceAssistant = voice_assistant_ns.class_("VoiceAssistant", cg.Component)
@ -26,6 +34,12 @@ CONFIG_SCHEMA = cv.Schema(
{
cv.GenerateID(): cv.declare_id(VoiceAssistant),
cv.GenerateID(CONF_MICROPHONE): cv.use_id(microphone.Microphone),
cv.Optional(CONF_ON_START): automation.validate_automation(single=True),
cv.Optional(CONF_ON_STT_END): automation.validate_automation(single=True),
cv.Optional(CONF_ON_TTS_START): automation.validate_automation(single=True),
cv.Optional(CONF_ON_TTS_END): automation.validate_automation(single=True),
cv.Optional(CONF_ON_END): automation.validate_automation(single=True),
cv.Optional(CONF_ON_ERROR): automation.validate_automation(single=True),
}
).extend(cv.COMPONENT_SCHEMA)
@ -37,6 +51,40 @@ async def to_code(config):
mic = await cg.get_variable(config[CONF_MICROPHONE])
cg.add(var.set_microphone(mic))
if CONF_ON_START in config:
await automation.build_automation(
var.get_start_trigger(), [], config[CONF_ON_START]
)
if CONF_ON_STT_END in config:
await automation.build_automation(
var.get_stt_end_trigger(), [(cg.std_string, "x")], config[CONF_ON_STT_END]
)
if CONF_ON_TTS_START in config:
await automation.build_automation(
var.get_tts_start_trigger(),
[(cg.std_string, "x")],
config[CONF_ON_TTS_START],
)
if CONF_ON_TTS_END in config:
await automation.build_automation(
var.get_tts_end_trigger(), [(cg.std_string, "x")], config[CONF_ON_TTS_END]
)
if CONF_ON_END in config:
await automation.build_automation(
var.get_end_trigger(), [], config[CONF_ON_END]
)
if CONF_ON_ERROR in config:
await automation.build_automation(
var.get_error_trigger(),
[(cg.std_string, "code"), (cg.std_string, "message")],
config[CONF_ON_ERROR],
)
cg.add_define("USE_VOICE_ASSISTANT")

View file

@ -76,8 +76,9 @@ void VoiceAssistant::signal_stop() {
void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
switch (msg.event_type) {
case api::enums::VOICE_ASSISTANT_RUN_END:
ESP_LOGD(TAG, "Voice Assistant ended.");
case api::enums::VOICE_ASSISTANT_RUN_START:
ESP_LOGD(TAG, "Assist Pipeline running");
this->start_trigger_->trigger();
break;
case api::enums::VOICE_ASSISTANT_STT_END: {
std::string text;
@ -91,7 +92,7 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
return;
}
ESP_LOGD(TAG, "Speech recognised as: \"%s\"", text.c_str());
// TODO `on_stt_end` trigger
this->stt_end_trigger_->trigger(text);
break;
}
case api::enums::VOICE_ASSISTANT_TTS_START: {
@ -106,7 +107,7 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
return;
}
ESP_LOGD(TAG, "Response: \"%s\"", text.c_str());
// TODO `on_tts_start` trigger
this->tts_start_trigger_->trigger(text);
break;
}
case api::enums::VOICE_ASSISTANT_TTS_END: {
@ -121,9 +122,13 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
return;
}
ESP_LOGD(TAG, "Response URL: \"%s\"", url.c_str());
// TODO `on_tts_end` trigger
this->tts_end_trigger_->trigger(url);
break;
}
case api::enums::VOICE_ASSISTANT_RUN_END:
ESP_LOGD(TAG, "Assist Pipeline ended");
this->end_trigger_->trigger();
break;
case api::enums::VOICE_ASSISTANT_ERROR: {
std::string code = "";
std::string message = "";
@ -135,7 +140,7 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) {
}
}
ESP_LOGE(TAG, "Error: %s - %s", code.c_str(), message.c_str());
// TODO `on_error` trigger
this->error_trigger_->trigger(code, message);
}
default:
break;

View file

@ -25,10 +25,24 @@ class VoiceAssistant : public Component {
void on_event(const api::VoiceAssistantEventResponse &msg);
Trigger<> *get_start_trigger() const { return this->start_trigger_; }
Trigger<std::string> *get_stt_end_trigger() const { return this->stt_end_trigger_; }
Trigger<std::string> *get_tts_start_trigger() const { return this->tts_start_trigger_; }
Trigger<std::string> *get_tts_end_trigger() const { return this->tts_end_trigger_; }
Trigger<> *get_end_trigger() const { return this->end_trigger_; }
Trigger<std::string, std::string> *get_error_trigger() const { return this->error_trigger_; }
protected:
std::unique_ptr<socket::Socket> socket_ = nullptr;
struct sockaddr_storage dest_addr_;
Trigger<> *start_trigger_ = new Trigger<>();
Trigger<std::string> *stt_end_trigger_ = new Trigger<std::string>();
Trigger<std::string> *tts_start_trigger_ = new Trigger<std::string>();
Trigger<std::string> *tts_end_trigger_ = new Trigger<std::string>();
Trigger<> *end_trigger_ = new Trigger<>();
Trigger<std::string, std::string> *error_trigger_ = new Trigger<std::string, std::string>();
microphone::Microphone *mic_{nullptr};
bool running_{false};

View file

@ -696,3 +696,23 @@ microphone:
voice_assistant:
microphone: mic_id
on_start:
- logger.log: "Voice assistant started"
on_stt_end:
- logger.log:
format: "Voice assistant STT ended with result %s"
args: [x.c_str()]
on_tts_start:
- logger.log:
format: "Voice assistant TTS started with text %s"
args: [x.c_str()]
on_tts_end:
- logger.log:
format: "Voice assistant TTS ended with url %s"
args: [x.c_str()]
on_end:
- logger.log: "Voice assistant ended"
on_error:
- logger.log:
format: "Voice assistant error - code %s, message: %s"
args: [code.c_str(), message.c_str()]