From 27b593ba85d26be8445b7868dc56db51304e198b Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Mon, 30 Oct 2023 15:02:49 +1300 Subject: [PATCH] Add connection triggers to api (#5628) --- esphome/components/api/__init__.py | 22 ++++++++ esphome/components/api/api_connection.cpp | 62 +++++++++++++---------- esphome/components/api/api_connection.h | 2 + esphome/components/api/api_server.cpp | 1 + esphome/components/api/api_server.h | 8 +++ 5 files changed, 68 insertions(+), 27 deletions(-) diff --git a/esphome/components/api/__init__.py b/esphome/components/api/__init__.py index 1076ebc707..ec1a56bd2c 100644 --- a/esphome/components/api/__init__.py +++ b/esphome/components/api/__init__.py @@ -45,6 +45,8 @@ SERVICE_ARG_NATIVE_TYPES = { "string[]": cg.std_vector.template(cg.std_string), } CONF_ENCRYPTION = "encryption" +CONF_ON_CLIENT_CONNECTED = "on_client_connected" +CONF_ON_CLIENT_DISCONNECTED = "on_client_disconnected" def validate_encryption_key(value): @@ -87,6 +89,12 @@ CONFIG_SCHEMA = cv.Schema( cv.Required(CONF_KEY): validate_encryption_key, } ), + cv.Optional(CONF_ON_CLIENT_CONNECTED): automation.validate_automation( + single=True + ), + cv.Optional(CONF_ON_CLIENT_DISCONNECTED): automation.validate_automation( + single=True + ), } ).extend(cv.COMPONENT_SCHEMA) @@ -116,6 +124,20 @@ async def to_code(config): cg.add(var.register_user_service(trigger)) await automation.build_automation(trigger, func_args, conf) + if CONF_ON_CLIENT_CONNECTED in config: + await automation.build_automation( + var.get_client_connected_trigger(), + [(cg.std_string, "client_info"), (cg.std_string, "client_address")], + config[CONF_ON_CLIENT_CONNECTED], + ) + + if CONF_ON_CLIENT_DISCONNECTED in config: + await automation.build_automation( + var.get_client_disconnected_trigger(), + [(cg.std_string, "client_info"), (cg.std_string, "client_address")], + config[CONF_ON_CLIENT_DISCONNECTED], + ) + if encryption_config := config.get(CONF_ENCRYPTION): decoded = base64.b64decode(encryption_config[CONF_KEY]) cg.add(var.set_noise_psk(list(decoded))) diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index 03eaa159c7..d1e7513d11 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -32,9 +32,9 @@ APIConnection::APIConnection(std::unique_ptr sock, APIServer *pa this->proto_write_buffer_.reserve(64); #if defined(USE_API_PLAINTEXT) - helper_ = std::unique_ptr{new APIPlaintextFrameHelper(std::move(sock))}; + this->helper_ = std::unique_ptr{new APIPlaintextFrameHelper(std::move(sock))}; #elif defined(USE_API_NOISE) - helper_ = std::unique_ptr{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())}; + this->helper_ = std::unique_ptr{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())}; #else #error "No frame helper defined" #endif @@ -42,14 +42,16 @@ APIConnection::APIConnection(std::unique_ptr sock, APIServer *pa void APIConnection::start() { this->last_traffic_ = millis(); - APIError err = helper_->init(); + APIError err = this->helper_->init(); if (err != APIError::OK) { on_fatal_error(); - ESP_LOGW(TAG, "%s: Helper init failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno); + ESP_LOGW(TAG, "%s: Helper init failed: %s errno=%d", this->client_combined_info_.c_str(), api_error_to_str(err), + errno); return; } - client_info_ = helper_->getpeername(); - helper_->set_log_info(client_info_); + this->client_info_ = helper_->getpeername(); + this->client_peername_ = this->client_info_; + this->helper_->set_log_info(this->client_info_); } APIConnection::~APIConnection() { @@ -68,7 +70,7 @@ void APIConnection::loop() { // when network is disconnected force disconnect immediately // don't wait for timeout this->on_fatal_error(); - ESP_LOGW(TAG, "%s: Network unavailable, disconnecting", client_info_.c_str()); + ESP_LOGW(TAG, "%s: Network unavailable, disconnecting", this->client_combined_info_.c_str()); return; } if (this->next_close_) { @@ -78,24 +80,26 @@ void APIConnection::loop() { return; } - APIError err = helper_->loop(); + APIError err = this->helper_->loop(); if (err != APIError::OK) { on_fatal_error(); - ESP_LOGW(TAG, "%s: Socket operation failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno); + ESP_LOGW(TAG, "%s: Socket operation failed: %s errno=%d", this->client_combined_info_.c_str(), + api_error_to_str(err), errno); return; } ReadPacketBuffer buffer; - err = helper_->read_packet(&buffer); + err = this->helper_->read_packet(&buffer); if (err == APIError::WOULD_BLOCK) { // pass } else if (err != APIError::OK) { on_fatal_error(); if (err == APIError::SOCKET_READ_FAILED && errno == ECONNRESET) { - ESP_LOGW(TAG, "%s: Connection reset", client_info_.c_str()); + ESP_LOGW(TAG, "%s: Connection reset", this->client_combined_info_.c_str()); } else if (err == APIError::CONNECTION_CLOSED) { - ESP_LOGW(TAG, "%s: Connection closed", client_info_.c_str()); + ESP_LOGW(TAG, "%s: Connection closed", this->client_combined_info_.c_str()); } else { - ESP_LOGW(TAG, "%s: Reading failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno); + ESP_LOGW(TAG, "%s: Reading failed: %s errno=%d", this->client_combined_info_.c_str(), api_error_to_str(err), + errno); } return; } else { @@ -115,7 +119,7 @@ void APIConnection::loop() { // Disconnect if not responded within 2.5*keepalive if (now - this->last_traffic_ > (keepalive * 5) / 2) { on_fatal_error(); - ESP_LOGW(TAG, "%s didn't respond to ping request in time. Disconnecting...", this->client_info_.c_str()); + ESP_LOGW(TAG, "%s didn't respond to ping request in time. Disconnecting...", this->client_combined_info_.c_str()); } } else if (now - this->last_traffic_ > keepalive) { ESP_LOGVV(TAG, "Sending keepalive PING..."); @@ -169,7 +173,7 @@ DisconnectResponse APIConnection::disconnect(const DisconnectRequest &msg) { // remote initiated disconnect_client // don't close yet, we still need to send the disconnect response // close will happen on next loop - ESP_LOGD(TAG, "%s requested disconnected", client_info_.c_str()); + ESP_LOGD(TAG, "%s requested disconnected", this->client_combined_info_.c_str()); this->next_close_ = true; DisconnectResponse resp; return resp; @@ -1045,12 +1049,14 @@ bool APIConnection::send_log_message(int level, const char *tag, const char *lin } HelloResponse APIConnection::hello(const HelloRequest &msg) { - this->client_info_ = msg.client_info + " (" + this->helper_->getpeername() + ")"; - this->helper_->set_log_info(client_info_); + this->client_info_ = msg.client_info; + this->client_peername_ = this->helper_->getpeername(); + this->client_combined_info_ = this->client_info_ + " (" + this->client_peername_ + ")"; + this->helper_->set_log_info(this->client_combined_info_); this->client_api_version_major_ = msg.api_version_major; this->client_api_version_minor_ = msg.api_version_minor; - ESP_LOGV(TAG, "Hello from client: '%s' | API Version %" PRIu32 ".%" PRIu32, this->client_info_.c_str(), - this->client_api_version_major_, this->client_api_version_minor_); + ESP_LOGV(TAG, "Hello from client: '%s' | %s | API Version %" PRIu32 ".%" PRIu32, this->client_info_.c_str(), + this->client_peername_.c_str(), this->client_api_version_major_, this->client_api_version_minor_); HelloResponse resp; resp.api_version_major = 1; @@ -1068,9 +1074,9 @@ ConnectResponse APIConnection::connect(const ConnectRequest &msg) { // bool invalid_password = 1; resp.invalid_password = !correct; if (correct) { - ESP_LOGD(TAG, "%s: Connected successfully", this->client_info_.c_str()); + ESP_LOGD(TAG, "%s: Connected successfully", this->client_combined_info_.c_str()); this->connection_state_ = ConnectionState::AUTHENTICATED; - + this->parent_->get_client_connected_trigger()->trigger(this->client_info_, this->client_peername_); #ifdef USE_HOMEASSISTANT_TIME if (homeassistant::global_homeassistant_time != nullptr) { this->send_time_request(); @@ -1145,10 +1151,11 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) return false; if (!this->helper_->can_write_without_blocking()) { delay(0); - APIError err = helper_->loop(); + APIError err = this->helper_->loop(); if (err != APIError::OK) { on_fatal_error(); - ESP_LOGW(TAG, "%s: Socket operation failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno); + ESP_LOGW(TAG, "%s: Socket operation failed: %s errno=%d", this->client_combined_info_.c_str(), + api_error_to_str(err), errno); return false; } if (!this->helper_->can_write_without_blocking()) { @@ -1167,9 +1174,10 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) if (err != APIError::OK) { on_fatal_error(); if (err == APIError::SOCKET_WRITE_FAILED && errno == ECONNRESET) { - ESP_LOGW(TAG, "%s: Connection reset", client_info_.c_str()); + ESP_LOGW(TAG, "%s: Connection reset", this->client_combined_info_.c_str()); } else { - ESP_LOGW(TAG, "%s: Packet write failed %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno); + ESP_LOGW(TAG, "%s: Packet write failed %s errno=%d", this->client_combined_info_.c_str(), api_error_to_str(err), + errno); } return false; } @@ -1178,11 +1186,11 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) } void APIConnection::on_unauthenticated_access() { this->on_fatal_error(); - ESP_LOGD(TAG, "%s: tried to access without authentication.", this->client_info_.c_str()); + ESP_LOGD(TAG, "%s: tried to access without authentication.", this->client_combined_info_.c_str()); } void APIConnection::on_no_setup_connection() { this->on_fatal_error(); - ESP_LOGD(TAG, "%s: tried to access without full connection.", this->client_info_.c_str()); + ESP_LOGD(TAG, "%s: tried to access without full connection.", this->client_combined_info_.c_str()); } void APIConnection::on_fatal_error() { this->helper_->close(); diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index c17aaab611..21ee85daab 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -207,6 +207,8 @@ class APIConnection : public APIServerConnection { std::unique_ptr helper_; std::string client_info_; + std::string client_peername_; + std::string client_combined_info_; uint32_t client_api_version_major_{0}; uint32_t client_api_version_minor_{0}; #ifdef USE_ESP32_CAMERA diff --git a/esphome/components/api/api_server.cpp b/esphome/components/api/api_server.cpp index c4edddc92b..5268b30132 100644 --- a/esphome/components/api/api_server.cpp +++ b/esphome/components/api/api_server.cpp @@ -111,6 +111,7 @@ void APIServer::loop() { [](const std::unique_ptr &conn) { return !conn->remove_; }); // print disconnection messages for (auto it = new_end; it != this->clients_.end(); ++it) { + this->client_disconnected_trigger_->trigger((*it)->client_info_, (*it)->client_peername_); ESP_LOGV(TAG, "Removing connection to %s", (*it)->client_info_.c_str()); } // resize vector diff --git a/esphome/components/api/api_server.h b/esphome/components/api/api_server.h index 4d359ebb79..f1fb31fa8b 100644 --- a/esphome/components/api/api_server.h +++ b/esphome/components/api/api_server.h @@ -4,6 +4,7 @@ #include "api_pb2.h" #include "api_pb2_service.h" #include "esphome/components/socket/socket.h" +#include "esphome/core/automation.h" #include "esphome/core/component.h" #include "esphome/core/controller.h" #include "esphome/core/defines.h" @@ -106,6 +107,11 @@ class APIServer : public Component, public Controller { const std::vector &get_state_subs() const; const std::vector &get_user_services() const { return this->user_services_; } + Trigger *get_client_connected_trigger() const { return this->client_connected_trigger_; } + Trigger *get_client_disconnected_trigger() const { + return this->client_disconnected_trigger_; + } + protected: std::unique_ptr socket_ = nullptr; uint16_t port_{6053}; @@ -115,6 +121,8 @@ class APIServer : public Component, public Controller { std::string password_; std::vector state_subs_; std::vector user_services_; + Trigger *client_connected_trigger_ = new Trigger(); + Trigger *client_disconnected_trigger_ = new Trigger(); #ifdef USE_API_NOISE std::shared_ptr noise_ctx_ = std::make_shared();