mirror of
https://github.com/esphome/esphome.git
synced 2024-11-14 02:58:11 +01:00
Socket refactor and SSL
This commit is contained in:
parent
ea4a458214
commit
40dd9c5dce
15 changed files with 877 additions and 161 deletions
|
@ -192,6 +192,11 @@ class APIClient(threading.Thread):
|
|||
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self._socket.settimeout(10.0)
|
||||
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
import ssl
|
||||
|
||||
context = ssl.SSLContext()
|
||||
self._socket = context.wrap_socket(self._socket)
|
||||
|
||||
try:
|
||||
self._socket.connect((ip, self._port))
|
||||
except OSError as err:
|
||||
|
|
|
@ -19,7 +19,7 @@ from esphome.const import (
|
|||
from esphome.core import coroutine_with_priority
|
||||
|
||||
DEPENDENCIES = ["network"]
|
||||
AUTO_LOAD = ["async_tcp"]
|
||||
AUTO_LOAD = ["socket", "ssl"]
|
||||
CODEOWNERS = ["@OttoWinter"]
|
||||
|
||||
api_ns = cg.esphome_ns.namespace("api")
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#include "esphome/core/log.h"
|
||||
#include "esphome/core/util.h"
|
||||
#include "esphome/core/version.h"
|
||||
#include <errno.h>
|
||||
|
||||
#ifdef USE_DEEP_SLEEP
|
||||
#include "esphome/components/deep_sleep/deep_sleep_component.h"
|
||||
|
@ -18,42 +19,62 @@ namespace api {
|
|||
|
||||
static const char *const TAG = "api.connection";
|
||||
|
||||
APIConnection::APIConnection(AsyncClient *client, APIServer *parent)
|
||||
: client_(client), parent_(parent), initial_state_iterator_(parent, this), list_entities_iterator_(parent, this) {
|
||||
APIConnection::APIConnection(std::unique_ptr<socket::Socket> sock, APIServer *parent)
|
||||
: socket_(std::move(sock)), parent_(parent),
|
||||
initial_state_iterator_(parent, this), list_entities_iterator_(parent, this) {
|
||||
this->proto_write_buffer_.reserve(64);
|
||||
this->recv_buffer_.reserve(32);
|
||||
}
|
||||
void APIConnection::start() {
|
||||
/*
|
||||
this->client_->onError([](void *s, AsyncClient *c, int8_t error) { ((APIConnection *) s)->on_error_(error); }, this);
|
||||
this->client_->onDisconnect([](void *s, AsyncClient *c) { ((APIConnection *) s)->on_disconnect_(); }, this);
|
||||
this->client_->onTimeout([](void *s, AsyncClient *c, uint32_t time) { ((APIConnection *) s)->on_timeout_(time); },
|
||||
this);
|
||||
this->client_->onData([](void *s, AsyncClient *c, void *buf,
|
||||
size_t len) { ((APIConnection *) s)->on_data_(reinterpret_cast<uint8_t *>(buf), len); },
|
||||
this);
|
||||
|
||||
this->send_buffer_.reserve(64);
|
||||
this->recv_buffer_.reserve(32);
|
||||
this->client_info_ = this->client_->remoteIP().toString().c_str();
|
||||
this);*/
|
||||
/*this->client_info_ = this->client_->remoteIP().toString().c_str();*/
|
||||
this->last_traffic_ = millis();
|
||||
|
||||
int err = socket_->setblocking(false);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket could not enable non-blocking, errno: %d", errno);
|
||||
remove_ = true;
|
||||
return;
|
||||
}
|
||||
int enable = 1;
|
||||
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno);
|
||||
remove_ = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
APIConnection::~APIConnection() { delete this->client_; }
|
||||
void APIConnection::on_error_(int8_t error) { this->remove_ = true; }
|
||||
|
||||
APIConnection::~APIConnection() { this->socket_ = nullptr; }
|
||||
/*void APIConnection::on_error_(int8_t error) { this->remove_ = true; }
|
||||
void APIConnection::on_disconnect_() { this->remove_ = true; }
|
||||
void APIConnection::on_timeout_(uint32_t time) { this->on_fatal_error(); }
|
||||
void APIConnection::on_data_(uint8_t *buf, size_t len) {
|
||||
if (len == 0 || buf == nullptr)
|
||||
return;
|
||||
this->recv_buffer_.insert(this->recv_buffer_.end(), buf, buf + len);
|
||||
}
|
||||
}*/
|
||||
|
||||
/// Returns at estimate of how many bytes will need to be read for this message
|
||||
void APIConnection::parse_recv_buffer_() {
|
||||
if (this->recv_buffer_.empty() || this->remove_)
|
||||
if (this->recv_len_ == 0 || this->remove_)
|
||||
return;
|
||||
|
||||
while (!this->recv_buffer_.empty()) {
|
||||
while (this->recv_len_ != 0) {
|
||||
if (this->recv_buffer_[0] != 0x00) {
|
||||
ESP_LOGW(TAG, "Invalid preamble from %s", this->client_info_.c_str());
|
||||
this->on_fatal_error();
|
||||
ESP_LOGW(TAG, "Invalid preamble from %s", this->client_info_.c_str());
|
||||
return;
|
||||
}
|
||||
uint32_t i = 1;
|
||||
const uint32_t size = this->recv_buffer_.size();
|
||||
const uint32_t size = recv_len_;
|
||||
uint32_t consumed;
|
||||
auto msg_size_varint = ProtoVarInt::parse(&this->recv_buffer_[i], size - i, &consumed);
|
||||
if (!msg_size_varint.has_value())
|
||||
|
@ -77,15 +98,22 @@ void APIConnection::parse_recv_buffer_() {
|
|||
this->read_message(msg_size, msg_type, msg);
|
||||
if (this->remove_)
|
||||
return;
|
||||
// pop front
|
||||
uint32_t total = i + msg_size;
|
||||
this->recv_buffer_.erase(this->recv_buffer_.begin(), this->recv_buffer_.begin() + total);
|
||||
uint32_t total_size = i + msg_size;
|
||||
// left-rotate remaining data (if any) to beginning
|
||||
// inefficient, but at the moment this only receives small packets anyway
|
||||
std::copy(
|
||||
this->recv_buffer_.begin() + total_size,
|
||||
this->recv_buffer_.begin() + recv_len_,
|
||||
this->recv_buffer_.begin()
|
||||
);
|
||||
this->recv_len_ -= total_size;
|
||||
|
||||
this->last_traffic_ = millis();
|
||||
}
|
||||
}
|
||||
|
||||
void APIConnection::disconnect_client() {
|
||||
this->client_->close();
|
||||
this->socket_->close();
|
||||
this->remove_ = true;
|
||||
}
|
||||
|
||||
|
@ -97,17 +125,49 @@ void APIConnection::loop() {
|
|||
this->disconnect_client();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!network_is_connected()) {
|
||||
// when network is disconnected force disconnect immediately
|
||||
// don't wait for timeout
|
||||
this->on_fatal_error();
|
||||
return;
|
||||
}
|
||||
if (this->client_->disconnected()) {
|
||||
// failsafe for disconnect logic
|
||||
this->on_disconnect_();
|
||||
return;
|
||||
|
||||
this->try_send_pending_data_();
|
||||
while (!this->remove_) {
|
||||
// Note: vector.capacity is not used, as there's no good way to insert
|
||||
// data into it without zero-initialising on resize
|
||||
// https://stackoverflow.com/a/7689457
|
||||
size_t capacity = this->recv_buffer_.size();
|
||||
size_t used = this->recv_len_;
|
||||
size_t space = capacity - used;
|
||||
uint8_t *head = &this->recv_buffer_[used];
|
||||
|
||||
if (space == 0) {
|
||||
// no space to read, allocate more then retry
|
||||
this->recv_buffer_.resize(capacity + 64);
|
||||
continue;
|
||||
}
|
||||
|
||||
ssize_t received = socket_->read(head, space);
|
||||
if (received == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
// read would block
|
||||
break;
|
||||
}
|
||||
this->on_fatal_error();
|
||||
ESP_LOGW(TAG, "Error reading from socket: errno %d", errno);
|
||||
break;
|
||||
} else if (received == 0) {
|
||||
break;
|
||||
}
|
||||
// ESP_LOGD(TAG, "received %s", hexencode(head, received).c_str());
|
||||
this->recv_len_ += received;
|
||||
|
||||
if (received != space)
|
||||
// done with reading
|
||||
break;
|
||||
|
||||
this->parse_recv_buffer_();
|
||||
}
|
||||
this->parse_recv_buffer_();
|
||||
|
||||
|
@ -115,18 +175,20 @@ void APIConnection::loop() {
|
|||
this->initial_state_iterator_.advance();
|
||||
|
||||
const uint32_t keepalive = 60000;
|
||||
const uint32_t now = millis();
|
||||
if (this->sent_ping_) {
|
||||
// Disconnect if not responded within 2.5*keepalive
|
||||
if (millis() - this->last_traffic_ > (keepalive * 5) / 2) {
|
||||
if (now - this->last_traffic_ > (keepalive * 5) / 2) {
|
||||
ESP_LOGW(TAG, "'%s' didn't respond to ping request in time. Disconnecting...", this->client_info_.c_str());
|
||||
this->disconnect_client();
|
||||
}
|
||||
} else if (millis() - this->last_traffic_ > keepalive) {
|
||||
} else if (now - this->last_traffic_ > keepalive) {
|
||||
this->sent_ping_ = true;
|
||||
this->send_ping_request(PingRequest());
|
||||
}
|
||||
|
||||
#ifdef USE_ESP32_CAMERA
|
||||
// FIXME
|
||||
if (this->image_reader_.available()) {
|
||||
uint32_t space = this->client_->space();
|
||||
// reserve 15 bytes for metadata, and at least 64 bytes of data
|
||||
|
@ -152,6 +214,30 @@ void APIConnection::loop() {
|
|||
}
|
||||
#endif
|
||||
}
|
||||
void APIConnection::try_send_pending_data_() {
|
||||
if (this->pending_send_buffer_.empty() || this->remove_)
|
||||
return;
|
||||
const uint8_t *data = &this->pending_send_buffer_[0];
|
||||
const size_t len = this->pending_send_buffer_.size();
|
||||
ssize_t written = this->socket_->write(data, len);
|
||||
if (written == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
// write would block
|
||||
return;
|
||||
}
|
||||
this->on_fatal_error();
|
||||
ESP_LOGW(TAG, "Error writing to socket: errno %d", errno);
|
||||
return;
|
||||
} else if (written == len) {
|
||||
this->pending_send_buffer_.clear();
|
||||
} else {
|
||||
// FIXME: inefficient
|
||||
this->pending_send_buffer_.erase(
|
||||
this->pending_send_buffer_.begin(),
|
||||
this->pending_send_buffer_.begin() + written
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_default_unique_id(const std::string &component_type, Nameable *nameable) {
|
||||
return App.get_name() + component_type + nameable->get_object_id();
|
||||
|
@ -703,8 +789,7 @@ bool APIConnection::send_log_message(int level, const char *tag, const char *lin
|
|||
}
|
||||
|
||||
HelloResponse APIConnection::hello(const HelloRequest &msg) {
|
||||
this->client_info_ = msg.client_info + " (" + this->client_->remoteIP().toString().c_str();
|
||||
this->client_info_ += ")";
|
||||
this->client_info_ = msg.client_info + " (" + this->socket_->getpeername() + ")";
|
||||
ESP_LOGV(TAG, "Hello from client: '%s'", this->client_info_.c_str());
|
||||
|
||||
HelloResponse resp;
|
||||
|
@ -779,9 +864,47 @@ void APIConnection::subscribe_home_assistant_states(const SubscribeHomeAssistant
|
|||
}
|
||||
}
|
||||
}
|
||||
bool APIConnection::send_(const void *buf, size_t len, bool force) {
|
||||
if (this->remove_)
|
||||
return false;
|
||||
if (len == 0)
|
||||
return true;
|
||||
|
||||
// ESP_LOGD(TAG, "writing %s", hexencode((const uint8_t *) buf, len).c_str());
|
||||
ssize_t written = this->socket_->write(buf, len);
|
||||
bool add_to_pending = false;
|
||||
if (written == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
|
||||
// write would block
|
||||
add_to_pending = force;
|
||||
} else if (written == -1) {
|
||||
this->on_fatal_error();
|
||||
ESP_LOGW(TAG, "Error writing to socket: errno %d", errno);
|
||||
return false;
|
||||
} else if (written == len) {
|
||||
// all written
|
||||
return true;
|
||||
} else if (written == 0) {
|
||||
add_to_pending = force;
|
||||
} else {
|
||||
// partially written, must insert in pending data
|
||||
add_to_pending = true;
|
||||
}
|
||||
if (add_to_pending) {
|
||||
this->pending_send_buffer_.insert(
|
||||
this->pending_send_buffer_.end(),
|
||||
reinterpret_cast<const uint8_t *>(buf) + written,
|
||||
reinterpret_cast<const uint8_t *>(buf) + len
|
||||
);
|
||||
}
|
||||
return add_to_pending;
|
||||
}
|
||||
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) {
|
||||
if (this->remove_)
|
||||
return false;
|
||||
this->try_send_pending_data_();
|
||||
if (!this->pending_send_buffer_.empty())
|
||||
// FIXME: still send important message like HA service
|
||||
return false;
|
||||
|
||||
std::vector<uint8_t> header;
|
||||
header.push_back(0x00);
|
||||
|
@ -790,36 +913,25 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type)
|
|||
|
||||
size_t needed_space = buffer.get_buffer()->size() + header.size();
|
||||
|
||||
if (needed_space > this->client_->space()) {
|
||||
delay(0);
|
||||
if (needed_space > this->client_->space()) {
|
||||
// SubscribeLogsResponse
|
||||
if (message_type != 29) {
|
||||
ESP_LOGV(TAG, "Cannot send message because of TCP buffer space");
|
||||
}
|
||||
delay(0);
|
||||
return false;
|
||||
}
|
||||
if (!this->send_(header.data(), header.size(), false)) {
|
||||
// nothing written and doesn't fit
|
||||
return false;
|
||||
}
|
||||
|
||||
this->client_->add(reinterpret_cast<char *>(header.data()), header.size(),
|
||||
ASYNC_WRITE_FLAG_COPY | ASYNC_WRITE_FLAG_MORE);
|
||||
this->client_->add(reinterpret_cast<char *>(buffer.get_buffer()->data()), buffer.get_buffer()->size(),
|
||||
ASYNC_WRITE_FLAG_COPY);
|
||||
bool ret = this->client_->send();
|
||||
return ret;
|
||||
// force send because we already sent the header
|
||||
this->send_(buffer.get_buffer()->data(), buffer.get_buffer()->size(), true);
|
||||
return true;
|
||||
}
|
||||
void APIConnection::on_unauthenticated_access() {
|
||||
ESP_LOGD(TAG, "'%s' tried to access without authentication.", this->client_info_.c_str());
|
||||
this->on_fatal_error();
|
||||
ESP_LOGD(TAG, "'%s' tried to access without authentication.", this->client_info_.c_str());
|
||||
}
|
||||
void APIConnection::on_no_setup_connection() {
|
||||
ESP_LOGD(TAG, "'%s' tried to access without full connection.", this->client_info_.c_str());
|
||||
this->on_fatal_error();
|
||||
ESP_LOGD(TAG, "'%s' tried to access without full connection.", this->client_info_.c_str());
|
||||
}
|
||||
void APIConnection::on_fatal_error() {
|
||||
ESP_LOGV(TAG, "Error: Disconnecting %s", this->client_info_.c_str());
|
||||
this->client_->close();
|
||||
this->socket_->close();
|
||||
this->remove_ = true;
|
||||
}
|
||||
|
||||
|
|
|
@ -11,9 +11,10 @@ namespace api {
|
|||
|
||||
class APIConnection : public APIServerConnection {
|
||||
public:
|
||||
APIConnection(AsyncClient *client, APIServer *parent);
|
||||
APIConnection(std::unique_ptr<socket::Socket> socket, APIServer *parent);
|
||||
virtual ~APIConnection();
|
||||
|
||||
void start();
|
||||
void disconnect_client();
|
||||
void loop();
|
||||
|
||||
|
@ -135,19 +136,18 @@ class APIConnection : public APIServerConnection {
|
|||
void on_unauthenticated_access() override;
|
||||
void on_no_setup_connection() override;
|
||||
ProtoWriteBuffer create_buffer() override {
|
||||
this->send_buffer_.clear();
|
||||
return {&this->send_buffer_};
|
||||
// FIXME: ensure no recursive writes can happen
|
||||
this->proto_write_buffer_.clear();
|
||||
return {&this->proto_write_buffer_};
|
||||
}
|
||||
bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) override;
|
||||
|
||||
protected:
|
||||
friend APIServer;
|
||||
|
||||
void on_error_(int8_t error);
|
||||
void on_disconnect_();
|
||||
void on_timeout_(uint32_t time);
|
||||
void on_data_(uint8_t *buf, size_t len);
|
||||
void parse_recv_buffer_();
|
||||
bool send_(const void *buf, size_t len, bool force);
|
||||
void try_send_pending_data_();
|
||||
|
||||
enum class ConnectionState {
|
||||
WAITING_FOR_HELLO,
|
||||
|
@ -157,8 +157,14 @@ class APIConnection : public APIServerConnection {
|
|||
|
||||
bool remove_{false};
|
||||
|
||||
std::vector<uint8_t> send_buffer_;
|
||||
// Buffer used to encode proto messages
|
||||
// Re-use to prevent allocations
|
||||
std::vector<uint8_t> proto_write_buffer_;
|
||||
// Buffer containing pending sends
|
||||
std::vector<uint8_t> pending_send_buffer_;
|
||||
// Buffer containing data that was received but not parsed yet
|
||||
std::vector<uint8_t> recv_buffer_;
|
||||
size_t recv_len_{0};
|
||||
|
||||
std::string client_info_;
|
||||
#ifdef USE_ESP32_CAMERA
|
||||
|
@ -170,9 +176,8 @@ class APIConnection : public APIServerConnection {
|
|||
uint32_t last_traffic_;
|
||||
bool sent_ping_{false};
|
||||
bool service_call_subscription_{false};
|
||||
bool current_nodelay_{false};
|
||||
bool next_close_{false};
|
||||
AsyncClient *client_;
|
||||
std::unique_ptr<socket::Socket> socket_;
|
||||
APIServer *parent_;
|
||||
InitialStateIterator initial_state_iterator_;
|
||||
ListEntitiesIterator list_entities_iterator_;
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
#include "esphome/core/util.h"
|
||||
#include "esphome/core/defines.h"
|
||||
#include "esphome/core/version.h"
|
||||
#include <errno.h>
|
||||
//#include <arpa/inet.h>
|
||||
|
||||
#ifdef USE_LOGGER
|
||||
#include "esphome/components/logger/logger.h"
|
||||
|
@ -21,20 +23,115 @@ static const char *const TAG = "api";
|
|||
void APIServer::setup() {
|
||||
ESP_LOGCONFIG(TAG, "Setting up Home Assistant API server...");
|
||||
this->setup_controller();
|
||||
this->server_ = AsyncServer(this->port_);
|
||||
this->server_.setNoDelay(false);
|
||||
this->server_.begin();
|
||||
this->server_.onClient(
|
||||
[](void *s, AsyncClient *client) {
|
||||
if (client == nullptr)
|
||||
return;
|
||||
socket_ = socket::socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (socket_ == nullptr) {
|
||||
ESP_LOGW(TAG, "Could not create socket.");
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
int enable = 1;
|
||||
int err = socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
|
||||
// we can still continue
|
||||
}
|
||||
err = socket_->setblocking(false);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
/*struct sockaddr_storage dest_addr;
|
||||
memset(&dest_addr, 0, sizeof(dest_addr));
|
||||
struct sockaddr_in *dest_addr_ip4 = (struct sockaddr_in *) &dest_addr;
|
||||
dest_addr_ip4->sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
dest_addr_ip4->sin_family = AF_INET;
|
||||
dest_addr_ip4->sin_port = htons(this->port_);
|
||||
|
||||
err = socket_->bind((struct sockaddr *) &dest_addr, sizeof(dest_addr));*/
|
||||
|
||||
struct sockaddr_in server;
|
||||
memset(&server, 0, sizeof(server));
|
||||
server.sin_family = AF_INET;
|
||||
server.sin_addr.s_addr = INADDR_ANY;
|
||||
server.sin_port = htons(this->port_);
|
||||
|
||||
err = socket_->bind((struct sockaddr *) &server, sizeof(server));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
err = socket_->listen(4);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
ssl_ = ssl::create_context();
|
||||
if (!ssl_) {
|
||||
ESP_LOGW(TAG, "Failed to create SSL context: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
ssl_->set_server_certificate(R"(-----BEGIN CERTIFICATE-----
|
||||
MIIDETCCAfkCFGNtbm6nA3CZM7no7HqdWikhUMSkMA0GCSqGSIb3DQEBCwUAMEUx
|
||||
CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl
|
||||
cm5ldCBXaWRnaXRzIFB0eSBMdGQwHhcNMjEwODA5MTgwOTMyWhcNMjEwOTA4MTgw
|
||||
OTMyWjBFMQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UE
|
||||
CgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOC
|
||||
AQ8AMIIBCgKCAQEAwbt/qjWftqZtdRaJ5QjRf/8Sh6JT8KN4Bu9cGbHJIKAQLhy6
|
||||
8/qdB24Ar8SyuKEaV8HRcCguTQ58jdK5rbaQu/Zpppgy9lF3AHH1MhVHavGNca3A
|
||||
ejFtJr4DuTLkv/HjpgcAHjhZk+mFeNXrHeFrPIzF3imSyV1xyqoBxpa1cCFH/D3J
|
||||
o2S6PMdAEcHSoaP5TEuM9e2j9Sc97LughMaFkR1R4cz2kEyMZIOASHkFCJMV6pjg
|
||||
PVOqxu11oFYJn9/zh1Ea6PChYq+bGBmj60vwh+tpA6E8T0PzkxUuVklAD5pBfXoD
|
||||
y8xW8ulc0CPSGSaxbn2vudUBJyZFvQBTQhFVcwIDAQABMA0GCSqGSIb3DQEBCwUA
|
||||
A4IBAQBQM80osk+ryQ+CqBhyOLQBOeQkmCVNzMUjVBG7tP4vkuJtfAdyUuKBWGtr
|
||||
X2VrkL0yueeDt9rdib5QbXWih4sT7KdQlnSBmnrac0MM7wCh+lhCnJhWWCUBHP9s
|
||||
8rkL2XOrISbVi80wqpJn0y4FMaoK6KnxyelallHuNZ+3EZAuZrhGAkV68Z83CIFO
|
||||
5emvAIGq73U/lddLDV6sz7zWeDdnyfTpkLzml8wJLO9Ob7o6aw7WJK/edjYdc2XW
|
||||
pIMatEESaN9MlWI5SXQS4AcMnqdUqab5587cHDrgVjBd8RmvdyT9j2v7nS2JyEK0
|
||||
DkASogRqBmCLR1/0UW+dFARCZI9k
|
||||
-----END CERTIFICATE-----
|
||||
)");
|
||||
ssl_->set_private_key(R"(-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEAwbt/qjWftqZtdRaJ5QjRf/8Sh6JT8KN4Bu9cGbHJIKAQLhy6
|
||||
8/qdB24Ar8SyuKEaV8HRcCguTQ58jdK5rbaQu/Zpppgy9lF3AHH1MhVHavGNca3A
|
||||
ejFtJr4DuTLkv/HjpgcAHjhZk+mFeNXrHeFrPIzF3imSyV1xyqoBxpa1cCFH/D3J
|
||||
o2S6PMdAEcHSoaP5TEuM9e2j9Sc97LughMaFkR1R4cz2kEyMZIOASHkFCJMV6pjg
|
||||
PVOqxu11oFYJn9/zh1Ea6PChYq+bGBmj60vwh+tpA6E8T0PzkxUuVklAD5pBfXoD
|
||||
y8xW8ulc0CPSGSaxbn2vudUBJyZFvQBTQhFVcwIDAQABAoIBAGVpQuDUhTBVWkLK
|
||||
c5CC1zfLS+XYIVx8FZ57uZhxqjj70LxyqaKBc6Wp/Y4ExxFCs8lwWbP+NI59oNGU
|
||||
l0HJqWXbDV75mOO7rTF8db+rx+DBZSs2quTL7rkzCjvt2jRn6KTGUVeAY9O7j/S6
|
||||
9gKEN2BQyFsNJBtoYOKXr6pGxd9Vg3K0j7DJXf5uK7lWIrxtU9k7QgMJFdnhbzEu
|
||||
0TnhFEVMDdBIm+yrTdL8lmIdT8DOUIJgyTJ1iICFndksPQSgBAWQaGKaaxZbn0c3
|
||||
Oy778VFqT2HywHVbJQL4XBe/yYUhjbpF1Hv9EEbK3Rm04xsCDbZru6/AK88gHBk4
|
||||
b7uUSwECgYEA4BmML1isP9h8zqAvCEFjFmAWpoLBBZ+5I9Go1PIWhlCYY+G7AUXw
|
||||
zxb0J6d9UGsYTkJXlgE77+HBzqlgyhCkngNuAAPm37ebdwuy5iBr32c9RLahR5W5
|
||||
Nh+J3le9JTXe9B9uwfggD06dBFmhgG0PQdyBr4Daa3a8VRJAD1MGYMECgYEA3U9U
|
||||
QwxQOYBkdJTbIQnTP7vnFuhWn9V5BMn5PczJSwGJEgaHgIL5Bm5NHa/ON3UX6QIi
|
||||
uk73fGfohN8Ii1MjVKNFKM/LZ30XSufVHrm7yH6xRR4qbZUk4KhKxV/uOVluv38P
|
||||
bis9B9cye3ETnjDhkWK4/TJeTHHlTAKMQuOQzzMCgYEAmtlsYYbvNwq7aveKqDSu
|
||||
aFarMBGnmOA+SP7ln4dMgzELq/DdjEqs1BwzR3dXgwsNd34mEVP2+5HOnqOxas7H
|
||||
QRxzlPUdQjcX6NGfo56Bi5RF5MYheVp+6WQvmwCbhSvNTHivyr5OQOV8X/YjP5+c
|
||||
bFEXF5N82cbo6gu7Uht3i8ECgYAh511JSEGiDYFWOte3IAI06VxlrgJXSiTYDvkX
|
||||
9p9/1iRhlo57qZTs30kBG0XESTP4hlM7p41SibidYm20qm/nL3wQ3ISUvh0rZIjJ
|
||||
xDp4ZLBTnmNxlj+oCyApTKD6ODE3NQfwIL+gy973+kK/IU3tL+qXH3hCzdAK7Pj/
|
||||
5kzw8QKBgQDUYQGH1JT93Yn9uIyfX1v6HcB1azDbF16JEOFZoGlS1gxFCobIb7jA
|
||||
2/Y0HfFUUfDGexjQNReFi0IXjgBvYmJX7rF9tGsTdXh35Lu2cTd0DcykGPVcFyJW
|
||||
PSf0vGzbAqpdriYQStaed+HgTdW6kHsOBNeJbbJkjsQpoaoWX3tEDw==
|
||||
-----END RSA PRIVATE KEY-----
|
||||
)");
|
||||
err = ssl_->init();
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Failed to initialize SSL context: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
// can't print here because in lwIP thread
|
||||
// ESP_LOGD(TAG, "New client connected from %s", client->remoteIP().toString().c_str());
|
||||
auto *a_this = (APIServer *) s;
|
||||
a_this->clients_.push_back(new APIConnection(client, a_this));
|
||||
},
|
||||
this);
|
||||
#ifdef USE_LOGGER
|
||||
if (logger::global_logger != nullptr) {
|
||||
logger::global_logger->add_on_log_callback([this](int level, const char *tag, const char *message) {
|
||||
|
@ -59,6 +156,27 @@ void APIServer::setup() {
|
|||
#endif
|
||||
}
|
||||
void APIServer::loop() {
|
||||
// Accept new clients
|
||||
while (true) {
|
||||
struct sockaddr_storage source_addr;
|
||||
socklen_t addr_len = sizeof(source_addr);
|
||||
auto sock = socket_->accept((struct sockaddr *) &source_addr, &addr_len);
|
||||
if (!sock)
|
||||
break;
|
||||
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());
|
||||
|
||||
// wrap socket
|
||||
auto sock2 = ssl_->wrap_socket(std::move(sock));
|
||||
if (!sock2) {
|
||||
ESP_LOGW(TAG, "Failed to wrap socket with SSL: errno %d", errno);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto *conn = new APIConnection(std::move(sock2), this);
|
||||
clients_.push_back(conn);
|
||||
conn->start();
|
||||
}
|
||||
|
||||
// Partition clients into remove and active
|
||||
auto new_end =
|
||||
std::partition(this->clients_.begin(), this->clients_.end(), [](APIConnection *conn) { return !conn->remove_; });
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
#include "esphome/core/controller.h"
|
||||
#include "esphome/core/defines.h"
|
||||
#include "esphome/core/log.h"
|
||||
#include "esphome/components/socket/socket.h"
|
||||
#include "esphome/components/ssl/ssl_context.h"
|
||||
#include "api_pb2.h"
|
||||
#include "api_pb2_service.h"
|
||||
#include "util.h"
|
||||
|
@ -11,13 +13,6 @@
|
|||
#include "subscribe_state.h"
|
||||
#include "user_services.h"
|
||||
|
||||
#ifdef ARDUINO_ARCH_ESP32
|
||||
#include <AsyncTCP.h>
|
||||
#endif
|
||||
#ifdef ARDUINO_ARCH_ESP8266
|
||||
#include <ESPAsyncTCP.h>
|
||||
#endif
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
|
@ -86,7 +81,8 @@ class APIServer : public Component, public Controller {
|
|||
const std::vector<UserServiceDescriptor *> &get_user_services() const { return this->user_services_; }
|
||||
|
||||
protected:
|
||||
AsyncServer server_{0};
|
||||
std::unique_ptr<socket::Socket> socket_ = nullptr;
|
||||
std::unique_ptr<ssl::SSLContext> ssl_ = nullptr;
|
||||
uint16_t port_{6053};
|
||||
uint32_t reboot_timeout_{300000};
|
||||
uint32_t last_connected_{0};
|
||||
|
|
|
@ -15,6 +15,7 @@ from esphome.core import CORE, coroutine_with_priority
|
|||
|
||||
CODEOWNERS = ["@esphome/core"]
|
||||
DEPENDENCIES = ["network"]
|
||||
AUTO_LOAD = ["socket"]
|
||||
|
||||
CONF_ON_STATE_CHANGE = "on_state_change"
|
||||
CONF_ON_BEGIN = "on_begin"
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#include "esphome/core/log.h"
|
||||
#include "esphome/core/application.h"
|
||||
#include "esphome/core/util.h"
|
||||
|
||||
#include <errno.h>
|
||||
#include <cstdio>
|
||||
#include <MD5Builder.h>
|
||||
#ifdef ARDUINO_ARCH_ESP32
|
||||
|
@ -19,8 +19,44 @@ static const char *const TAG = "ota";
|
|||
static const uint8_t OTA_VERSION_1_0 = 1;
|
||||
|
||||
void OTAComponent::setup() {
|
||||
this->server_ = new WiFiServer(this->port_);
|
||||
this->server_->begin();
|
||||
server_ = socket::socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (server_ == nullptr) {
|
||||
ESP_LOGW(TAG, "Could not create socket.");
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
int enable = 1;
|
||||
int err = server_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
|
||||
// we can still continue
|
||||
}
|
||||
err = server_->setblocking(false);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
struct sockaddr_in server;
|
||||
memset(&server, 0, sizeof(server));
|
||||
server.sin_family = AF_INET;
|
||||
server.sin_addr.s_addr = INADDR_ANY;
|
||||
server.sin_port = htons(this->port_);
|
||||
|
||||
err = server_->bind((struct sockaddr *) &server, sizeof(server));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
err = server_->listen(4);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
this->dump_config();
|
||||
}
|
||||
|
@ -59,23 +95,28 @@ void OTAComponent::handle_() {
|
|||
uint8_t ota_features;
|
||||
(void) ota_features;
|
||||
|
||||
if (!this->client_.connected()) {
|
||||
this->client_ = this->server_->available();
|
||||
if (client_ == nullptr) {
|
||||
struct sockaddr_storage source_addr;
|
||||
socklen_t addr_len = sizeof(source_addr);
|
||||
client_ = server_->accept((struct sockaddr *) &source_addr, &addr_len);
|
||||
}
|
||||
if (client_ == nullptr)
|
||||
return;
|
||||
|
||||
if (!this->client_.connected())
|
||||
return;
|
||||
int enable = 1;
|
||||
int err = client_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno);
|
||||
return;
|
||||
}
|
||||
|
||||
// enable nodelay for outgoing data
|
||||
this->client_.setNoDelay(true);
|
||||
|
||||
ESP_LOGD(TAG, "Starting OTA Update from %s...", this->client_.remoteIP().toString().c_str());
|
||||
ESP_LOGD(TAG, "Starting OTA Update from %s...", this->client_->getpeername().c_str());
|
||||
this->status_set_warning();
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->state_callback_.call(OTA_STARTED, 0.0f, 0);
|
||||
#endif
|
||||
|
||||
if (!this->wait_receive_(buf, 5)) {
|
||||
if (!this->readall_(buf, 5)) {
|
||||
ESP_LOGW(TAG, "Reading magic bytes failed!");
|
||||
goto error;
|
||||
}
|
||||
|
@ -88,11 +129,12 @@ void OTAComponent::handle_() {
|
|||
}
|
||||
|
||||
// Send OK and version - 2 bytes
|
||||
this->client_.write(OTA_RESPONSE_OK);
|
||||
this->client_.write(OTA_VERSION_1_0);
|
||||
buf[0] = OTA_RESPONSE_OK;
|
||||
buf[1] = OTA_VERSION_1_0;
|
||||
this->writeall_(buf, 2);
|
||||
|
||||
// Read features - 1 byte
|
||||
if (!this->wait_receive_(buf, 1)) {
|
||||
if (!this->readall_(buf, 1)) {
|
||||
ESP_LOGW(TAG, "Reading features failed!");
|
||||
goto error;
|
||||
}
|
||||
|
@ -100,10 +142,12 @@ void OTAComponent::handle_() {
|
|||
ESP_LOGV(TAG, "OTA features is 0x%02X", ota_features);
|
||||
|
||||
// Acknowledge header - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_HEADER_OK);
|
||||
buf[0] = OTA_RESPONSE_HEADER_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
if (!this->password_.empty()) {
|
||||
this->client_.write(OTA_RESPONSE_REQUEST_AUTH);
|
||||
buf[0] = OTA_RESPONSE_REQUEST_AUTH;
|
||||
this->writeall_(buf, 1);
|
||||
MD5Builder md5_builder{};
|
||||
md5_builder.begin();
|
||||
sprintf(sbuf, "%08X", random_uint32());
|
||||
|
@ -113,7 +157,7 @@ void OTAComponent::handle_() {
|
|||
ESP_LOGV(TAG, "Auth: Nonce is %s", sbuf);
|
||||
|
||||
// Send nonce, 32 bytes hex MD5
|
||||
if (this->client_.write(reinterpret_cast<uint8_t *>(sbuf), 32) != 32) {
|
||||
if (!this->writeall_(reinterpret_cast<uint8_t *>(sbuf), 32)) {
|
||||
ESP_LOGW(TAG, "Auth: Writing nonce failed!");
|
||||
goto error;
|
||||
}
|
||||
|
@ -125,7 +169,7 @@ void OTAComponent::handle_() {
|
|||
md5_builder.add(sbuf);
|
||||
|
||||
// Receive cnonce, 32 bytes hex MD5
|
||||
if (!this->wait_receive_(buf, 32)) {
|
||||
if (!this->readall_(buf, 32)) {
|
||||
ESP_LOGW(TAG, "Auth: Reading cnonce failed!");
|
||||
goto error;
|
||||
}
|
||||
|
@ -140,7 +184,7 @@ void OTAComponent::handle_() {
|
|||
ESP_LOGV(TAG, "Auth: Result is %s", sbuf);
|
||||
|
||||
// Receive result, 32 bytes hex MD5
|
||||
if (!this->wait_receive_(buf + 64, 32)) {
|
||||
if (!this->writeall_(buf + 64, 32)) {
|
||||
ESP_LOGW(TAG, "Auth: Reading response failed!");
|
||||
goto error;
|
||||
}
|
||||
|
@ -159,10 +203,11 @@ void OTAComponent::handle_() {
|
|||
}
|
||||
|
||||
// Acknowledge auth OK - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_AUTH_OK);
|
||||
buf[0] = OTA_RESPONSE_AUTH_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
// Read size, 4 bytes MSB first
|
||||
if (!this->wait_receive_(buf, 4)) {
|
||||
if (!this->readall_(buf, 4)) {
|
||||
ESP_LOGW(TAG, "Reading size failed!");
|
||||
goto error;
|
||||
}
|
||||
|
@ -211,10 +256,11 @@ void OTAComponent::handle_() {
|
|||
update_started = true;
|
||||
|
||||
// Acknowledge prepare OK - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_UPDATE_PREPARE_OK);
|
||||
buf[0] = OTA_RESPONSE_UPDATE_PREPARE_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
// Read binary MD5, 32 bytes
|
||||
if (!this->wait_receive_(buf, 32)) {
|
||||
if (!this->readall_(buf, 32)) {
|
||||
ESP_LOGW(TAG, "Reading binary MD5 checksum failed!");
|
||||
goto error;
|
||||
}
|
||||
|
@ -223,17 +269,22 @@ void OTAComponent::handle_() {
|
|||
Update.setMD5(sbuf);
|
||||
|
||||
// Acknowledge MD5 OK - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_BIN_MD5_OK);
|
||||
buf[0] = OTA_RESPONSE_BIN_MD5_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
while (!Update.isFinished()) {
|
||||
size_t available = this->wait_receive_(buf, 0);
|
||||
if (!available) {
|
||||
// TODO: timeout check
|
||||
ssize_t read = this->client_->read(buf, sizeof(buf));
|
||||
if (read == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK)
|
||||
continue;
|
||||
ESP_LOGW(TAG, "Error receiving data for update, errno: %d", errno);
|
||||
goto error;
|
||||
}
|
||||
|
||||
uint32_t written = Update.write(buf, available);
|
||||
if (written != available) {
|
||||
ESP_LOGW(TAG, "Error writing binary data to flash: %u != %u!", written, available); // NOLINT
|
||||
uint32_t written = Update.write(buf, read);
|
||||
if (written != read) {
|
||||
ESP_LOGW(TAG, "Error writing binary data to flash: %u != %u!", written, read); // NOLINT
|
||||
error_code = OTA_RESPONSE_ERROR_WRITING_FLASH;
|
||||
goto error;
|
||||
}
|
||||
|
@ -253,7 +304,8 @@ void OTAComponent::handle_() {
|
|||
}
|
||||
|
||||
// Acknowledge receive OK - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_RECEIVE_OK);
|
||||
buf[0] = OTA_RESPONSE_RECEIVE_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
if (!Update.end()) {
|
||||
error_code = OTA_RESPONSE_ERROR_UPDATE_END;
|
||||
|
@ -261,16 +313,17 @@ void OTAComponent::handle_() {
|
|||
}
|
||||
|
||||
// Acknowledge Update end OK - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_UPDATE_END_OK);
|
||||
buf[0] = OTA_RESPONSE_UPDATE_END_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
// Read ACK
|
||||
if (!this->wait_receive_(buf, 1, false) || buf[0] != OTA_RESPONSE_OK) {
|
||||
if (!this->readall_(buf, 1) || buf[0] != OTA_RESPONSE_OK) {
|
||||
ESP_LOGW(TAG, "Reading back acknowledgement failed!");
|
||||
// do not go to error, this is not fatal
|
||||
}
|
||||
|
||||
this->client_.flush();
|
||||
this->client_.stop();
|
||||
this->client_->close();
|
||||
this->client_ = nullptr;
|
||||
delay(10);
|
||||
ESP_LOGI(TAG, "OTA update finished!");
|
||||
this->status_clear_warning();
|
||||
|
@ -286,11 +339,10 @@ error:
|
|||
Update.printError(ss);
|
||||
ESP_LOGW(TAG, "Update end failed! Error: %s", ss.c_str());
|
||||
}
|
||||
if (this->client_.connected()) {
|
||||
this->client_.write(static_cast<uint8_t>(error_code));
|
||||
this->client_.flush();
|
||||
}
|
||||
this->client_.stop();
|
||||
buf[0] = static_cast<uint8_t>(error_code);
|
||||
this->writeall_(buf, 1);
|
||||
this->client_->close();
|
||||
this->client_ = nullptr;
|
||||
|
||||
#ifdef ARDUINO_ARCH_ESP32
|
||||
if (update_started) {
|
||||
|
@ -314,52 +366,56 @@ error:
|
|||
#endif
|
||||
}
|
||||
|
||||
size_t OTAComponent::wait_receive_(uint8_t *buf, size_t bytes, bool check_disconnected) {
|
||||
size_t available = 0;
|
||||
bool OTAComponent::readall_(uint8_t *buf, size_t len) {
|
||||
uint32_t start = millis();
|
||||
do {
|
||||
App.feed_wdt();
|
||||
if (check_disconnected && !this->client_.connected()) {
|
||||
ESP_LOGW(TAG, "Error client disconnected while receiving data!");
|
||||
return 0;
|
||||
}
|
||||
int availi = this->client_.available();
|
||||
if (availi < 0) {
|
||||
ESP_LOGW(TAG, "Error reading data!");
|
||||
return 0;
|
||||
}
|
||||
uint32_t at = 0;
|
||||
while (len - at > 0) {
|
||||
uint32_t now = millis();
|
||||
if (availi == 0 && now - start > 10000) {
|
||||
ESP_LOGW(TAG, "Timeout waiting for data!");
|
||||
return 0;
|
||||
if (now - start > 1000) {
|
||||
ESP_LOGW(TAG, "Timed out reading %d bytes of data", len);
|
||||
return false;
|
||||
}
|
||||
available = size_t(availi);
|
||||
yield();
|
||||
} while (bytes == 0 ? available == 0 : available < bytes);
|
||||
|
||||
if (bytes == 0)
|
||||
bytes = std::min(available, size_t(1024));
|
||||
|
||||
bool success = false;
|
||||
for (uint32_t i = 0; !success && i < 100; i++) {
|
||||
int res = this->client_.read(buf, bytes);
|
||||
|
||||
if (res != int(bytes)) {
|
||||
// ESP32 implementation has an issue where calling read can fail with EAGAIN (race condition)
|
||||
// so just re-try it until it works (with generous timeout of 1s)
|
||||
// because we check with available() first this should not cause us any trouble in all other cases
|
||||
delay(10);
|
||||
ssize_t read = this->client_->read(buf + at, len - at);
|
||||
if (read == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
delay(1);
|
||||
continue;
|
||||
}
|
||||
ESP_LOGW(TAG, "Failed to read %d bytes of data, errno: %d", len, errno);
|
||||
return false;
|
||||
} else {
|
||||
success = true;
|
||||
at += read;
|
||||
}
|
||||
delay(1);
|
||||
}
|
||||
|
||||
if (!success) {
|
||||
ESP_LOGW(TAG, "Reading %u bytes of binary data failed!", bytes); // NOLINT
|
||||
return 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool OTAComponent::writeall_(const uint8_t *buf, size_t len) {
|
||||
uint32_t start = millis();
|
||||
uint32_t at = 0;
|
||||
while (len - at > 0) {
|
||||
uint32_t now = millis();
|
||||
if (now - start > 1000) {
|
||||
ESP_LOGW(TAG, "Timed out writing %d bytes of data", len);
|
||||
return false;
|
||||
}
|
||||
|
||||
return bytes;
|
||||
ssize_t written = this->client_->write(buf + at, len - at);
|
||||
if (written == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
delay(1);
|
||||
continue;
|
||||
}
|
||||
ESP_LOGW(TAG, "Failed to write %d bytes of data, errno: %d", len, errno);
|
||||
return false;
|
||||
} else {
|
||||
at += written;
|
||||
}
|
||||
delay(1);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void OTAComponent::set_auth_password(const std::string &password) { this->password_ = password; }
|
||||
|
|
|
@ -3,8 +3,7 @@
|
|||
#include "esphome/core/component.h"
|
||||
#include "esphome/core/preferences.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
#include <WiFiServer.h>
|
||||
#include <WiFiClient.h>
|
||||
#include "esphome/components/socket/socket.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace ota {
|
||||
|
@ -74,14 +73,15 @@ class OTAComponent : public Component {
|
|||
uint32_t read_rtc_();
|
||||
|
||||
void handle_();
|
||||
size_t wait_receive_(uint8_t *buf, size_t bytes, bool check_disconnected = true);
|
||||
bool readall_(uint8_t *buf, size_t len);
|
||||
bool writeall_(const uint8_t *buf, size_t len);
|
||||
|
||||
std::string password_;
|
||||
|
||||
uint16_t port_;
|
||||
|
||||
WiFiServer *server_{nullptr};
|
||||
WiFiClient client_{};
|
||||
std::unique_ptr<socket::Socket> server_;
|
||||
std::unique_ptr<socket::Socket> client_;
|
||||
|
||||
bool has_safe_mode_{false}; ///< stores whether safe mode can be enabled.
|
||||
uint32_t safe_mode_start_time_; ///< stores when safe mode was enabled.
|
||||
|
|
2
esphome/components/socket/__init__.py
Normal file
2
esphome/components/socket/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
# Dummy package to allow components to depend on network
|
||||
CODEOWNERS = ["@esphome/core"]
|
45
esphome/components/socket/socket.h
Normal file
45
esphome/components/socket/socket.h
Normal file
|
@ -0,0 +1,45 @@
|
|||
#pragma once
|
||||
#include <string>
|
||||
#include <sys/types.h>
|
||||
#include <sys/socket.h>
|
||||
#include <memory>
|
||||
|
||||
#include "esphome/core/optional.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace socket {
|
||||
|
||||
using socklen_t = uint32_t;
|
||||
|
||||
class Socket {
|
||||
public:
|
||||
Socket() = default;
|
||||
virtual ~Socket() = default;
|
||||
Socket(const Socket&) = delete;
|
||||
Socket &operator=(const Socket &) = delete;
|
||||
|
||||
virtual std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
virtual int close() = 0;
|
||||
virtual int connect(const std::string &address) = 0;
|
||||
virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
virtual int shutdown(int how) = 0;
|
||||
|
||||
virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
virtual std::string getpeername() = 0;
|
||||
virtual int getsockname(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
virtual std::string getsockname() = 0;
|
||||
virtual int getsockopt(int level, int optname, void *optval, socklen_t *optlen) = 0;
|
||||
virtual int setsockopt(int level, int optname, const void *optval, socklen_t optlen) = 0;
|
||||
virtual int listen(int backlog) = 0;
|
||||
virtual ssize_t read(void *buf, size_t len) = 0;
|
||||
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
|
||||
virtual ssize_t write(const void *buf, size_t len) = 0;
|
||||
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
|
||||
virtual int setblocking(bool blocking) = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<Socket> socket(int domain, int type, int protocol);
|
||||
|
||||
} // socket
|
||||
} // esphome
|
118
esphome/components/socket/socket_impl.cpp
Normal file
118
esphome/components/socket/socket_impl.cpp
Normal file
|
@ -0,0 +1,118 @@
|
|||
#include "socket.h"
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <string.h>
|
||||
|
||||
namespace esphome {
|
||||
namespace socket {
|
||||
|
||||
std::string format_sockaddr(const struct sockaddr_storage &storage) {
|
||||
if (storage.ss_family == AF_INET) {
|
||||
const struct sockaddr_in *addr = reinterpret_cast<const struct sockaddr_in *>(&storage);
|
||||
char buf[INET_ADDRSTRLEN];
|
||||
const char *ret = inet_ntop(AF_INET, &addr->sin_addr, buf, sizeof(buf));
|
||||
if (ret == NULL)
|
||||
return {};
|
||||
return std::string{buf};
|
||||
} else if (storage.ss_family == AF_INET6) {
|
||||
const struct sockaddr_in6 *addr = reinterpret_cast<const struct sockaddr_in6 *>(&storage);
|
||||
char buf[INET6_ADDRSTRLEN];
|
||||
const char *ret = inet_ntop(AF_INET6, &addr->sin6_addr, buf, sizeof(buf));
|
||||
if (ret == NULL)
|
||||
return {};
|
||||
return std::string{buf};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
class SocketImplSocket : public Socket {
|
||||
public:
|
||||
SocketImplSocket(int fd) : Socket(), fd_(fd) {}
|
||||
~SocketImplSocket() override {
|
||||
if (!closed_) {
|
||||
close();
|
||||
}
|
||||
}
|
||||
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
int fd = ::accept(fd_, addr, addrlen);
|
||||
if (fd == -1)
|
||||
return {};
|
||||
return std::unique_ptr<SocketImplSocket>{new SocketImplSocket(fd)};
|
||||
}
|
||||
int bind(const struct sockaddr *addr, socklen_t addrlen) override {
|
||||
return ::bind(fd_, addr, addrlen);
|
||||
}
|
||||
int close() override {
|
||||
int ret = ::close(fd_);
|
||||
closed_ = true;
|
||||
return ret;
|
||||
}
|
||||
int connect(const std::string &address) override {
|
||||
// TODO
|
||||
return 0;
|
||||
}
|
||||
int connect(const struct sockaddr *addr, socklen_t addrlen) override {
|
||||
return ::connect(fd_, addr, addrlen);
|
||||
}
|
||||
int shutdown(int how) override {
|
||||
return ::shutdown(fd_, how);
|
||||
}
|
||||
|
||||
int getpeername(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
return ::getpeername(fd_, addr, addrlen);
|
||||
}
|
||||
std::string getpeername() override {
|
||||
struct sockaddr_storage storage;
|
||||
socklen_t len = sizeof(storage);
|
||||
int err = this->getpeername((struct sockaddr *) &storage, &len);
|
||||
if (err != 0)
|
||||
return {};
|
||||
return format_sockaddr(storage);
|
||||
}
|
||||
int getsockname(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
return ::getsockname(fd_, addr, addrlen);
|
||||
}
|
||||
std::string getsockname() override {
|
||||
struct sockaddr_storage storage;
|
||||
socklen_t len = sizeof(storage);
|
||||
int err = this->getsockname((struct sockaddr *) &storage, &len);
|
||||
if (err != 0)
|
||||
return {};
|
||||
return format_sockaddr(storage);
|
||||
}
|
||||
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
|
||||
return ::getsockopt(fd_, level, optname, optval, optlen);
|
||||
}
|
||||
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
|
||||
return ::setsockopt(fd_, level, optname, optval, optlen);
|
||||
}
|
||||
int listen(int backlog) override {
|
||||
return ::listen(fd_, backlog);
|
||||
}
|
||||
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
|
||||
ssize_t read(void *buf, size_t len) override {
|
||||
return ::read(fd_, buf, len);
|
||||
}
|
||||
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
|
||||
ssize_t write(const void *buf, size_t len) override {
|
||||
return ::write(fd_, buf, len);
|
||||
}
|
||||
int setblocking(bool blocking) override {
|
||||
int fl = ::fcntl(fd_, F_GETFL, 0);
|
||||
::fcntl(fd_, F_SETFL, fl | O_NONBLOCK);
|
||||
return 0;
|
||||
}
|
||||
protected:
|
||||
int fd_;
|
||||
bool closed_ = false;
|
||||
};
|
||||
|
||||
std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
|
||||
int ret = ::socket(domain, type, protocol);
|
||||
if (ret == -1)
|
||||
return nullptr;
|
||||
return std::unique_ptr<Socket>{new SocketImplSocket(ret)};
|
||||
}
|
||||
|
||||
} // namespace socket
|
||||
} // namespace esphome
|
1
esphome/components/ssl/__init__.py
Normal file
1
esphome/components/ssl/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
AUTO_LOAD = ["socket"]
|
233
esphome/components/ssl/mbedtls_impl.cpp
Normal file
233
esphome/components/ssl/mbedtls_impl.cpp
Normal file
|
@ -0,0 +1,233 @@
|
|||
#include "ssl_context.h"
|
||||
#include <string.h>
|
||||
#include "mbedtls/platform.h"
|
||||
#include "mbedtls/net_sockets.h"
|
||||
#include "mbedtls/esp_debug.h"
|
||||
#include "mbedtls/ssl.h"
|
||||
#include "mbedtls/entropy.h"
|
||||
#include "mbedtls/ctr_drbg.h"
|
||||
#include "mbedtls/error.h"
|
||||
#include "mbedtls/certs.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace ssl {
|
||||
|
||||
static int entropy_hw_random_source(void *data, uint8_t *output, size_t len, size_t *olen) {
|
||||
esp_fill_random(output, len);
|
||||
*olen = len;
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct MbedTLSBioCtx {
|
||||
socket::Socket *sock;
|
||||
|
||||
static int send(void *raw, const uint8_t *buf, size_t len) {
|
||||
auto *ctx = reinterpret_cast<MbedTLSBioCtx *>(raw);
|
||||
ssize_t ret = ctx->sock->write(buf, len);
|
||||
if (ret != -1)
|
||||
return ret;
|
||||
if (errno == EWOULDBLOCK || errno == EAGAIN)
|
||||
return MBEDTLS_ERR_SSL_WANT_WRITE;
|
||||
if (errno == EPIPE || errno == ECONNRESET)
|
||||
return MBEDTLS_ERR_NET_CONN_RESET;
|
||||
if (errno == EINTR)
|
||||
return MBEDTLS_ERR_SSL_WANT_WRITE;
|
||||
return MBEDTLS_ERR_NET_SEND_FAILED;
|
||||
}
|
||||
static int recv(void *raw, uint8_t *buf, size_t len) {
|
||||
auto *ctx = reinterpret_cast<MbedTLSBioCtx *>(raw);
|
||||
ssize_t ret = ctx->sock->read(buf, len);
|
||||
if (ret != -1)
|
||||
return ret;
|
||||
if (errno == EWOULDBLOCK || errno == EAGAIN)
|
||||
return MBEDTLS_ERR_SSL_WANT_WRITE;
|
||||
if (errno == EPIPE || errno == ECONNRESET)
|
||||
return MBEDTLS_ERR_NET_CONN_RESET;
|
||||
if (errno == EINTR)
|
||||
return MBEDTLS_ERR_SSL_WANT_WRITE;
|
||||
return MBEDTLS_ERR_NET_SEND_FAILED;
|
||||
}
|
||||
};
|
||||
|
||||
class MbedTLSWrappedSocket : public socket::Socket {
|
||||
public:
|
||||
MbedTLSWrappedSocket(std::unique_ptr<socket::Socket> sock)
|
||||
: socket::Socket(), sock_(std::move(sock)) {}
|
||||
~MbedTLSWrappedSocket() override {
|
||||
mbedtls_ssl_free(&ssl_);
|
||||
sock_ = nullptr;
|
||||
}
|
||||
void init(const mbedtls_ssl_config *conf) {
|
||||
// TODO: reuse ssl contexts?
|
||||
mbedtls_ssl_init(&ssl_);
|
||||
mbedtls_ssl_setup(&ssl_, conf);
|
||||
// sock pointer does not fit in void*
|
||||
// instead store it in a heap-allocated var
|
||||
auto *ctx = new MbedTLSBioCtx;
|
||||
// unsafe, but should be fine because we free before sock is reset
|
||||
ctx->sock = sock_.get();
|
||||
mbedtls_ssl_set_bio(&ssl_, ctx, MbedTLSBioCtx::send, MbedTLSBioCtx::recv, nullptr);
|
||||
|
||||
do_handshake_ = true;
|
||||
}
|
||||
|
||||
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
// only for server sockets
|
||||
errno = EBADF;
|
||||
return {};
|
||||
}
|
||||
int bind(const struct sockaddr *addr, socklen_t addrlen) override {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
int close() override {
|
||||
return sock_->close();
|
||||
}
|
||||
int connect(const std::string &address) override {
|
||||
return sock_->connect(address);
|
||||
}
|
||||
int connect(const struct sockaddr *addr, socklen_t addrlen) override {
|
||||
return sock_->connect(addr, addrlen);
|
||||
}
|
||||
int shutdown(int how) override {
|
||||
int ret = mbedtls_ssl_close_notify(&ssl_);
|
||||
if (ret != 0)
|
||||
return this->mbedtls_to_errno_(ret);
|
||||
return this->sock_->shutdown(how);
|
||||
}
|
||||
|
||||
int getpeername(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
return sock_->getpeername(addr, addrlen);
|
||||
}
|
||||
std::string getpeername() override {
|
||||
return sock_->getpeername();
|
||||
}
|
||||
int getsockname(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
return sock_->getsockname(addr, addrlen);
|
||||
}
|
||||
std::string getsockname() override {
|
||||
return sock_->getsockname();
|
||||
}
|
||||
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
|
||||
return sock_->getsockopt(level, optname, optval, optlen);
|
||||
}
|
||||
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
|
||||
return sock_->setsockopt(level, optname, optval, optlen);
|
||||
}
|
||||
int listen(int backlog) override {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
ssize_t read(void *buf, size_t len) override {
|
||||
// mbedtls will automatically perform handshake here if necessary
|
||||
int ret = mbedtls_ssl_read(&ssl_, reinterpret_cast<uint8_t *>(buf), len);
|
||||
return this->mbedtls_to_errno_(ret);
|
||||
}
|
||||
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
|
||||
ssize_t write(const void *buf, size_t len) override {
|
||||
int ret = mbedtls_ssl_write(&ssl_, reinterpret_cast<const uint8_t *>(buf), len);
|
||||
return this->mbedtls_to_errno_(ret);
|
||||
}
|
||||
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
|
||||
int setblocking(bool blocking) override {
|
||||
// TODO: handle blocking modes
|
||||
return sock_->setblocking(blocking);
|
||||
}
|
||||
|
||||
protected:
|
||||
int mbedtls_to_errno_(int ret) {
|
||||
if (ret > 0) {
|
||||
return ret;
|
||||
} else if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
|
||||
errno = EWOULDBLOCK;
|
||||
return -1;
|
||||
} else if (ret == MBEDTLS_ERR_NET_CONN_RESET) {
|
||||
errno = ECONNRESET;
|
||||
return -1;
|
||||
} else {
|
||||
errno = EIO;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<socket::Socket> sock_;
|
||||
mbedtls_ssl_context ssl_;
|
||||
bool do_handshake_ = false;
|
||||
};
|
||||
|
||||
class MbedTLSContext : public SSLContext {
|
||||
public:
|
||||
MbedTLSContext() = default;
|
||||
~MbedTLSContext() override {
|
||||
mbedtls_pk_free(&pkey_);
|
||||
mbedtls_entropy_free(&entropy_);
|
||||
mbedtls_ctr_drbg_free(&ctr_drbg_);
|
||||
mbedtls_x509_crt_free(&srv_cert_);
|
||||
mbedtls_ssl_config_free(&conf_);
|
||||
}
|
||||
|
||||
void set_server_certificate(const char *cert) override {
|
||||
this->srv_cert_str_ = cert;
|
||||
}
|
||||
void set_private_key(const char *private_key) override {
|
||||
this->privkey_str_ = private_key;
|
||||
}
|
||||
|
||||
int init() override {
|
||||
mbedtls_x509_crt_init(&srv_cert_);
|
||||
mbedtls_ctr_drbg_init(&ctr_drbg_);
|
||||
mbedtls_entropy_init(&entropy_);
|
||||
mbedtls_pk_init(&pkey_);
|
||||
mbedtls_ssl_config_init(&conf_);
|
||||
|
||||
// TODO check what this does
|
||||
mbedtls_entropy_add_source(&entropy_, entropy_hw_random_source, NULL, 134, MBEDTLS_ENTROPY_SOURCE_STRONG);
|
||||
mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, NULL, 0);
|
||||
|
||||
mbedtls_x509_crt_parse(
|
||||
&srv_cert_,
|
||||
reinterpret_cast<const uint8_t *>(srv_cert_str_),
|
||||
strlen(srv_cert_str_)
|
||||
);
|
||||
|
||||
mbedtls_pk_parse_key(
|
||||
&pkey_,
|
||||
reinterpret_cast<const uint8_t *>(privkey_str_),
|
||||
strlen(privkey_str_),
|
||||
nullptr,
|
||||
0
|
||||
);
|
||||
|
||||
mbedtls_ssl_config_defaults(
|
||||
&conf_,
|
||||
MBEDTLS_SSL_IS_SERVER,
|
||||
MBEDTLS_SSL_TRANSPORT_STREAM,
|
||||
MBEDTLS_SSL_PRESET_DEFAULT
|
||||
);
|
||||
mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_);
|
||||
mbedtls_ssl_conf_own_cert(&conf_, &srv_cert_, &pkey_);
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::unique_ptr<socket::Socket> wrap_socket(std::unique_ptr<socket::Socket> sock) override {
|
||||
auto *wrapped = new MbedTLSWrappedSocket(std::move(sock));
|
||||
wrapped->init(&conf_);
|
||||
return std::unique_ptr<socket::Socket>{wrapped};
|
||||
}
|
||||
|
||||
protected:
|
||||
const char *srv_cert_str_ = nullptr;
|
||||
const char *privkey_str_ = nullptr;
|
||||
mbedtls_entropy_context entropy_;
|
||||
mbedtls_ctr_drbg_context ctr_drbg_;
|
||||
mbedtls_x509_crt srv_cert_;
|
||||
mbedtls_pk_context pkey_;
|
||||
mbedtls_ssl_config conf_;
|
||||
};
|
||||
|
||||
std::unique_ptr<SSLContext> create_context() {
|
||||
return std::unique_ptr<SSLContext>{new MbedTLSContext()};
|
||||
}
|
||||
|
||||
} // namespace ssl
|
||||
} // namespace esphome
|
24
esphome/components/ssl/ssl_context.h
Normal file
24
esphome/components/ssl/ssl_context.h
Normal file
|
@ -0,0 +1,24 @@
|
|||
#pragma once
|
||||
#include <memory>
|
||||
#include "esphome/components/socket/socket.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace ssl {
|
||||
|
||||
class SSLContext {
|
||||
public:
|
||||
SSLContext() = default;
|
||||
virtual ~SSLContext() = default;
|
||||
SSLContext(const SSLContext&) = delete;
|
||||
SSLContext &operator=(const SSLContext &) = delete;
|
||||
|
||||
virtual int init() = 0;
|
||||
virtual void set_server_certificate(const char *cert) = 0;
|
||||
virtual void set_private_key(const char *private_key) = 0;
|
||||
virtual std::unique_ptr<socket::Socket> wrap_socket(std::unique_ptr<socket::Socket> sock) = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<SSLContext> create_context();
|
||||
|
||||
} // namespace ssl
|
||||
} // namespace esphome
|
Loading…
Reference in a new issue