Add connection triggers to api (#5628)

This commit is contained in:
Jesse Hills 2023-10-30 15:02:49 +13:00
parent 1e0daefa16
commit 0ea4de5f4c
No known key found for this signature in database
GPG key ID: BEAAE804EFD8E83A
5 changed files with 68 additions and 27 deletions

View file

@ -45,6 +45,8 @@ SERVICE_ARG_NATIVE_TYPES = {
"string[]": cg.std_vector.template(cg.std_string), "string[]": cg.std_vector.template(cg.std_string),
} }
CONF_ENCRYPTION = "encryption" CONF_ENCRYPTION = "encryption"
CONF_ON_CLIENT_CONNECTED = "on_client_connected"
CONF_ON_CLIENT_DISCONNECTED = "on_client_disconnected"
def validate_encryption_key(value): def validate_encryption_key(value):
@ -87,6 +89,12 @@ CONFIG_SCHEMA = cv.Schema(
cv.Required(CONF_KEY): validate_encryption_key, cv.Required(CONF_KEY): validate_encryption_key,
} }
), ),
cv.Optional(CONF_ON_CLIENT_CONNECTED): automation.validate_automation(
single=True
),
cv.Optional(CONF_ON_CLIENT_DISCONNECTED): automation.validate_automation(
single=True
),
} }
).extend(cv.COMPONENT_SCHEMA) ).extend(cv.COMPONENT_SCHEMA)
@ -116,6 +124,20 @@ async def to_code(config):
cg.add(var.register_user_service(trigger)) cg.add(var.register_user_service(trigger))
await automation.build_automation(trigger, func_args, conf) await automation.build_automation(trigger, func_args, conf)
if CONF_ON_CLIENT_CONNECTED in config:
await automation.build_automation(
var.get_client_connected_trigger(),
[(cg.std_string, "client_info"), (cg.std_string, "client_address")],
config[CONF_ON_CLIENT_CONNECTED],
)
if CONF_ON_CLIENT_DISCONNECTED in config:
await automation.build_automation(
var.get_client_disconnected_trigger(),
[(cg.std_string, "client_info"), (cg.std_string, "client_address")],
config[CONF_ON_CLIENT_DISCONNECTED],
)
if encryption_config := config.get(CONF_ENCRYPTION): if encryption_config := config.get(CONF_ENCRYPTION):
decoded = base64.b64decode(encryption_config[CONF_KEY]) decoded = base64.b64decode(encryption_config[CONF_KEY])
cg.add(var.set_noise_psk(list(decoded))) cg.add(var.set_noise_psk(list(decoded)))

View file

@ -31,9 +31,9 @@ APIConnection::APIConnection(std::unique_ptr<socket::Socket> sock, APIServer *pa
this->proto_write_buffer_.reserve(64); this->proto_write_buffer_.reserve(64);
#if defined(USE_API_PLAINTEXT) #if defined(USE_API_PLAINTEXT)
helper_ = std::unique_ptr<APIFrameHelper>{new APIPlaintextFrameHelper(std::move(sock))}; this->helper_ = std::unique_ptr<APIFrameHelper>{new APIPlaintextFrameHelper(std::move(sock))};
#elif defined(USE_API_NOISE) #elif defined(USE_API_NOISE)
helper_ = std::unique_ptr<APIFrameHelper>{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())}; this->helper_ = std::unique_ptr<APIFrameHelper>{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())};
#else #else
#error "No frame helper defined" #error "No frame helper defined"
#endif #endif
@ -41,14 +41,16 @@ APIConnection::APIConnection(std::unique_ptr<socket::Socket> sock, APIServer *pa
void APIConnection::start() { void APIConnection::start() {
this->last_traffic_ = millis(); this->last_traffic_ = millis();
APIError err = helper_->init(); APIError err = this->helper_->init();
if (err != APIError::OK) { if (err != APIError::OK) {
on_fatal_error(); on_fatal_error();
ESP_LOGW(TAG, "%s: Helper init failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno); ESP_LOGW(TAG, "%s: Helper init failed: %s errno=%d", this->client_combined_info_.c_str(), api_error_to_str(err),
errno);
return; return;
} }
client_info_ = helper_->getpeername(); this->client_info_ = helper_->getpeername();
helper_->set_log_info(client_info_); this->client_peername_ = this->client_info_;
this->helper_->set_log_info(this->client_info_);
} }
APIConnection::~APIConnection() { APIConnection::~APIConnection() {
@ -67,7 +69,7 @@ void APIConnection::loop() {
// when network is disconnected force disconnect immediately // when network is disconnected force disconnect immediately
// don't wait for timeout // don't wait for timeout
this->on_fatal_error(); this->on_fatal_error();
ESP_LOGW(TAG, "%s: Network unavailable, disconnecting", client_info_.c_str()); ESP_LOGW(TAG, "%s: Network unavailable, disconnecting", this->client_combined_info_.c_str());
return; return;
} }
if (this->next_close_) { if (this->next_close_) {
@ -77,24 +79,26 @@ void APIConnection::loop() {
return; return;
} }
APIError err = helper_->loop(); APIError err = this->helper_->loop();
if (err != APIError::OK) { if (err != APIError::OK) {
on_fatal_error(); on_fatal_error();
ESP_LOGW(TAG, "%s: Socket operation failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno); ESP_LOGW(TAG, "%s: Socket operation failed: %s errno=%d", this->client_combined_info_.c_str(),
api_error_to_str(err), errno);
return; return;
} }
ReadPacketBuffer buffer; ReadPacketBuffer buffer;
err = helper_->read_packet(&buffer); err = this->helper_->read_packet(&buffer);
if (err == APIError::WOULD_BLOCK) { if (err == APIError::WOULD_BLOCK) {
// pass // pass
} else if (err != APIError::OK) { } else if (err != APIError::OK) {
on_fatal_error(); on_fatal_error();
if (err == APIError::SOCKET_READ_FAILED && errno == ECONNRESET) { if (err == APIError::SOCKET_READ_FAILED && errno == ECONNRESET) {
ESP_LOGW(TAG, "%s: Connection reset", client_info_.c_str()); ESP_LOGW(TAG, "%s: Connection reset", this->client_combined_info_.c_str());
} else if (err == APIError::CONNECTION_CLOSED) { } else if (err == APIError::CONNECTION_CLOSED) {
ESP_LOGW(TAG, "%s: Connection closed", client_info_.c_str()); ESP_LOGW(TAG, "%s: Connection closed", this->client_combined_info_.c_str());
} else { } else {
ESP_LOGW(TAG, "%s: Reading failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno); ESP_LOGW(TAG, "%s: Reading failed: %s errno=%d", this->client_combined_info_.c_str(), api_error_to_str(err),
errno);
} }
return; return;
} else { } else {
@ -114,7 +118,7 @@ void APIConnection::loop() {
// Disconnect if not responded within 2.5*keepalive // Disconnect if not responded within 2.5*keepalive
if (now - this->last_traffic_ > (keepalive * 5) / 2) { if (now - this->last_traffic_ > (keepalive * 5) / 2) {
on_fatal_error(); on_fatal_error();
ESP_LOGW(TAG, "%s didn't respond to ping request in time. Disconnecting...", this->client_info_.c_str()); ESP_LOGW(TAG, "%s didn't respond to ping request in time. Disconnecting...", this->client_combined_info_.c_str());
} }
} else if (now - this->last_traffic_ > keepalive) { } else if (now - this->last_traffic_ > keepalive) {
ESP_LOGVV(TAG, "Sending keepalive PING..."); ESP_LOGVV(TAG, "Sending keepalive PING...");
@ -168,7 +172,7 @@ DisconnectResponse APIConnection::disconnect(const DisconnectRequest &msg) {
// remote initiated disconnect_client // remote initiated disconnect_client
// don't close yet, we still need to send the disconnect response // don't close yet, we still need to send the disconnect response
// close will happen on next loop // close will happen on next loop
ESP_LOGD(TAG, "%s requested disconnected", client_info_.c_str()); ESP_LOGD(TAG, "%s requested disconnected", this->client_combined_info_.c_str());
this->next_close_ = true; this->next_close_ = true;
DisconnectResponse resp; DisconnectResponse resp;
return resp; return resp;
@ -1006,12 +1010,14 @@ bool APIConnection::send_log_message(int level, const char *tag, const char *lin
} }
HelloResponse APIConnection::hello(const HelloRequest &msg) { HelloResponse APIConnection::hello(const HelloRequest &msg) {
this->client_info_ = msg.client_info + " (" + this->helper_->getpeername() + ")"; this->client_info_ = msg.client_info;
this->helper_->set_log_info(client_info_); this->client_peername_ = this->helper_->getpeername();
this->client_combined_info_ = this->client_info_ + " (" + this->client_peername_ + ")";
this->helper_->set_log_info(this->client_combined_info_);
this->client_api_version_major_ = msg.api_version_major; this->client_api_version_major_ = msg.api_version_major;
this->client_api_version_minor_ = msg.api_version_minor; this->client_api_version_minor_ = msg.api_version_minor;
ESP_LOGV(TAG, "Hello from client: '%s' | API Version %" PRIu32 ".%" PRIu32, this->client_info_.c_str(), ESP_LOGV(TAG, "Hello from client: '%s' | %s | API Version %" PRIu32 ".%" PRIu32, this->client_info_.c_str(),
this->client_api_version_major_, this->client_api_version_minor_); this->client_peername_.c_str(), this->client_api_version_major_, this->client_api_version_minor_);
HelloResponse resp; HelloResponse resp;
resp.api_version_major = 1; resp.api_version_major = 1;
@ -1029,9 +1035,9 @@ ConnectResponse APIConnection::connect(const ConnectRequest &msg) {
// bool invalid_password = 1; // bool invalid_password = 1;
resp.invalid_password = !correct; resp.invalid_password = !correct;
if (correct) { if (correct) {
ESP_LOGD(TAG, "%s: Connected successfully", this->client_info_.c_str()); ESP_LOGD(TAG, "%s: Connected successfully", this->client_combined_info_.c_str());
this->connection_state_ = ConnectionState::AUTHENTICATED; this->connection_state_ = ConnectionState::AUTHENTICATED;
this->parent_->get_client_connected_trigger()->trigger(this->client_info_, this->client_peername_);
#ifdef USE_HOMEASSISTANT_TIME #ifdef USE_HOMEASSISTANT_TIME
if (homeassistant::global_homeassistant_time != nullptr) { if (homeassistant::global_homeassistant_time != nullptr) {
this->send_time_request(); this->send_time_request();
@ -1105,10 +1111,11 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type)
return false; return false;
if (!this->helper_->can_write_without_blocking()) { if (!this->helper_->can_write_without_blocking()) {
delay(0); delay(0);
APIError err = helper_->loop(); APIError err = this->helper_->loop();
if (err != APIError::OK) { if (err != APIError::OK) {
on_fatal_error(); on_fatal_error();
ESP_LOGW(TAG, "%s: Socket operation failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno); ESP_LOGW(TAG, "%s: Socket operation failed: %s errno=%d", this->client_combined_info_.c_str(),
api_error_to_str(err), errno);
return false; return false;
} }
if (!this->helper_->can_write_without_blocking()) { if (!this->helper_->can_write_without_blocking()) {
@ -1127,9 +1134,10 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type)
if (err != APIError::OK) { if (err != APIError::OK) {
on_fatal_error(); on_fatal_error();
if (err == APIError::SOCKET_WRITE_FAILED && errno == ECONNRESET) { if (err == APIError::SOCKET_WRITE_FAILED && errno == ECONNRESET) {
ESP_LOGW(TAG, "%s: Connection reset", client_info_.c_str()); ESP_LOGW(TAG, "%s: Connection reset", this->client_combined_info_.c_str());
} else { } else {
ESP_LOGW(TAG, "%s: Packet write failed %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno); ESP_LOGW(TAG, "%s: Packet write failed %s errno=%d", this->client_combined_info_.c_str(), api_error_to_str(err),
errno);
} }
return false; return false;
} }
@ -1138,11 +1146,11 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type)
} }
void APIConnection::on_unauthenticated_access() { void APIConnection::on_unauthenticated_access() {
this->on_fatal_error(); this->on_fatal_error();
ESP_LOGD(TAG, "%s: tried to access without authentication.", this->client_info_.c_str()); ESP_LOGD(TAG, "%s: tried to access without authentication.", this->client_combined_info_.c_str());
} }
void APIConnection::on_no_setup_connection() { void APIConnection::on_no_setup_connection() {
this->on_fatal_error(); this->on_fatal_error();
ESP_LOGD(TAG, "%s: tried to access without full connection.", this->client_info_.c_str()); ESP_LOGD(TAG, "%s: tried to access without full connection.", this->client_combined_info_.c_str());
} }
void APIConnection::on_fatal_error() { void APIConnection::on_fatal_error() {
this->helper_->close(); this->helper_->close();

View file

@ -202,6 +202,8 @@ class APIConnection : public APIServerConnection {
std::unique_ptr<APIFrameHelper> helper_; std::unique_ptr<APIFrameHelper> helper_;
std::string client_info_; std::string client_info_;
std::string client_peername_;
std::string client_combined_info_;
uint32_t client_api_version_major_{0}; uint32_t client_api_version_major_{0};
uint32_t client_api_version_minor_{0}; uint32_t client_api_version_minor_{0};
#ifdef USE_ESP32_CAMERA #ifdef USE_ESP32_CAMERA

View file

@ -111,6 +111,7 @@ void APIServer::loop() {
[](const std::unique_ptr<APIConnection> &conn) { return !conn->remove_; }); [](const std::unique_ptr<APIConnection> &conn) { return !conn->remove_; });
// print disconnection messages // print disconnection messages
for (auto it = new_end; it != this->clients_.end(); ++it) { for (auto it = new_end; it != this->clients_.end(); ++it) {
this->client_disconnected_trigger_->trigger((*it)->client_info_, (*it)->client_peername_);
ESP_LOGV(TAG, "Removing connection to %s", (*it)->client_info_.c_str()); ESP_LOGV(TAG, "Removing connection to %s", (*it)->client_info_.c_str());
} }
// resize vector // resize vector

View file

@ -4,6 +4,7 @@
#include "api_pb2.h" #include "api_pb2.h"
#include "api_pb2_service.h" #include "api_pb2_service.h"
#include "esphome/components/socket/socket.h" #include "esphome/components/socket/socket.h"
#include "esphome/core/automation.h"
#include "esphome/core/component.h" #include "esphome/core/component.h"
#include "esphome/core/controller.h" #include "esphome/core/controller.h"
#include "esphome/core/defines.h" #include "esphome/core/defines.h"
@ -103,6 +104,11 @@ class APIServer : public Component, public Controller {
const std::vector<HomeAssistantStateSubscription> &get_state_subs() const; const std::vector<HomeAssistantStateSubscription> &get_state_subs() const;
const std::vector<UserServiceDescriptor *> &get_user_services() const { return this->user_services_; } const std::vector<UserServiceDescriptor *> &get_user_services() const { return this->user_services_; }
Trigger<std::string, std::string> *get_client_connected_trigger() const { return this->client_connected_trigger_; }
Trigger<std::string, std::string> *get_client_disconnected_trigger() const {
return this->client_disconnected_trigger_;
}
protected: protected:
std::unique_ptr<socket::Socket> socket_ = nullptr; std::unique_ptr<socket::Socket> socket_ = nullptr;
uint16_t port_{6053}; uint16_t port_{6053};
@ -112,6 +118,8 @@ class APIServer : public Component, public Controller {
std::string password_; std::string password_;
std::vector<HomeAssistantStateSubscription> state_subs_; std::vector<HomeAssistantStateSubscription> state_subs_;
std::vector<UserServiceDescriptor *> user_services_; std::vector<UserServiceDescriptor *> user_services_;
Trigger<std::string, std::string> *client_connected_trigger_ = new Trigger<std::string, std::string>();
Trigger<std::string, std::string> *client_disconnected_trigger_ = new Trigger<std::string, std::string>();
#ifdef USE_API_NOISE #ifdef USE_API_NOISE
std::shared_ptr<APINoiseContext> noise_ctx_ = std::make_shared<APINoiseContext>(); std::shared_ptr<APINoiseContext> noise_ctx_ = std::make_shared<APINoiseContext>();