Socket refactor and SSL

This commit is contained in:
Otto winter 2021-08-09 20:54:50 +02:00
parent ea4a458214
commit 40dd9c5dce
No known key found for this signature in database
GPG key ID: 48ED2DDB96D7682C
15 changed files with 877 additions and 161 deletions

View file

@ -192,6 +192,11 @@ class APIClient(threading.Thread):
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(10.0)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
import ssl
context = ssl.SSLContext()
self._socket = context.wrap_socket(self._socket)
try:
self._socket.connect((ip, self._port))
except OSError as err:

View file

@ -19,7 +19,7 @@ from esphome.const import (
from esphome.core import coroutine_with_priority
DEPENDENCIES = ["network"]
AUTO_LOAD = ["async_tcp"]
AUTO_LOAD = ["socket", "ssl"]
CODEOWNERS = ["@OttoWinter"]
api_ns = cg.esphome_ns.namespace("api")

View file

@ -2,6 +2,7 @@
#include "esphome/core/log.h"
#include "esphome/core/util.h"
#include "esphome/core/version.h"
#include <errno.h>
#ifdef USE_DEEP_SLEEP
#include "esphome/components/deep_sleep/deep_sleep_component.h"
@ -18,42 +19,62 @@ namespace api {
static const char *const TAG = "api.connection";
APIConnection::APIConnection(AsyncClient *client, APIServer *parent)
: client_(client), parent_(parent), initial_state_iterator_(parent, this), list_entities_iterator_(parent, this) {
APIConnection::APIConnection(std::unique_ptr<socket::Socket> sock, APIServer *parent)
: socket_(std::move(sock)), parent_(parent),
initial_state_iterator_(parent, this), list_entities_iterator_(parent, this) {
this->proto_write_buffer_.reserve(64);
this->recv_buffer_.reserve(32);
}
void APIConnection::start() {
/*
this->client_->onError([](void *s, AsyncClient *c, int8_t error) { ((APIConnection *) s)->on_error_(error); }, this);
this->client_->onDisconnect([](void *s, AsyncClient *c) { ((APIConnection *) s)->on_disconnect_(); }, this);
this->client_->onTimeout([](void *s, AsyncClient *c, uint32_t time) { ((APIConnection *) s)->on_timeout_(time); },
this);
this->client_->onData([](void *s, AsyncClient *c, void *buf,
size_t len) { ((APIConnection *) s)->on_data_(reinterpret_cast<uint8_t *>(buf), len); },
this);
this->send_buffer_.reserve(64);
this->recv_buffer_.reserve(32);
this->client_info_ = this->client_->remoteIP().toString().c_str();
this);*/
/*this->client_info_ = this->client_->remoteIP().toString().c_str();*/
this->last_traffic_ = millis();
int err = socket_->setblocking(false);
if (err != 0) {
ESP_LOGW(TAG, "Socket could not enable non-blocking, errno: %d", errno);
remove_ = true;
return;
}
int enable = 1;
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
if (err != 0) {
ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno);
remove_ = true;
return;
}
}
APIConnection::~APIConnection() { delete this->client_; }
void APIConnection::on_error_(int8_t error) { this->remove_ = true; }
APIConnection::~APIConnection() { this->socket_ = nullptr; }
/*void APIConnection::on_error_(int8_t error) { this->remove_ = true; }
void APIConnection::on_disconnect_() { this->remove_ = true; }
void APIConnection::on_timeout_(uint32_t time) { this->on_fatal_error(); }
void APIConnection::on_data_(uint8_t *buf, size_t len) {
if (len == 0 || buf == nullptr)
return;
this->recv_buffer_.insert(this->recv_buffer_.end(), buf, buf + len);
}
}*/
/// Returns at estimate of how many bytes will need to be read for this message
void APIConnection::parse_recv_buffer_() {
if (this->recv_buffer_.empty() || this->remove_)
if (this->recv_len_ == 0 || this->remove_)
return;
while (!this->recv_buffer_.empty()) {
while (this->recv_len_ != 0) {
if (this->recv_buffer_[0] != 0x00) {
ESP_LOGW(TAG, "Invalid preamble from %s", this->client_info_.c_str());
this->on_fatal_error();
ESP_LOGW(TAG, "Invalid preamble from %s", this->client_info_.c_str());
return;
}
uint32_t i = 1;
const uint32_t size = this->recv_buffer_.size();
const uint32_t size = recv_len_;
uint32_t consumed;
auto msg_size_varint = ProtoVarInt::parse(&this->recv_buffer_[i], size - i, &consumed);
if (!msg_size_varint.has_value())
@ -77,15 +98,22 @@ void APIConnection::parse_recv_buffer_() {
this->read_message(msg_size, msg_type, msg);
if (this->remove_)
return;
// pop front
uint32_t total = i + msg_size;
this->recv_buffer_.erase(this->recv_buffer_.begin(), this->recv_buffer_.begin() + total);
uint32_t total_size = i + msg_size;
// left-rotate remaining data (if any) to beginning
// inefficient, but at the moment this only receives small packets anyway
std::copy(
this->recv_buffer_.begin() + total_size,
this->recv_buffer_.begin() + recv_len_,
this->recv_buffer_.begin()
);
this->recv_len_ -= total_size;
this->last_traffic_ = millis();
}
}
void APIConnection::disconnect_client() {
this->client_->close();
this->socket_->close();
this->remove_ = true;
}
@ -97,17 +125,49 @@ void APIConnection::loop() {
this->disconnect_client();
return;
}
if (!network_is_connected()) {
// when network is disconnected force disconnect immediately
// don't wait for timeout
this->on_fatal_error();
return;
}
if (this->client_->disconnected()) {
// failsafe for disconnect logic
this->on_disconnect_();
return;
this->try_send_pending_data_();
while (!this->remove_) {
// Note: vector.capacity is not used, as there's no good way to insert
// data into it without zero-initialising on resize
// https://stackoverflow.com/a/7689457
size_t capacity = this->recv_buffer_.size();
size_t used = this->recv_len_;
size_t space = capacity - used;
uint8_t *head = &this->recv_buffer_[used];
if (space == 0) {
// no space to read, allocate more then retry
this->recv_buffer_.resize(capacity + 64);
continue;
}
ssize_t received = socket_->read(head, space);
if (received == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// read would block
break;
}
this->on_fatal_error();
ESP_LOGW(TAG, "Error reading from socket: errno %d", errno);
break;
} else if (received == 0) {
break;
}
// ESP_LOGD(TAG, "received %s", hexencode(head, received).c_str());
this->recv_len_ += received;
if (received != space)
// done with reading
break;
this->parse_recv_buffer_();
}
this->parse_recv_buffer_();
@ -115,18 +175,20 @@ void APIConnection::loop() {
this->initial_state_iterator_.advance();
const uint32_t keepalive = 60000;
const uint32_t now = millis();
if (this->sent_ping_) {
// Disconnect if not responded within 2.5*keepalive
if (millis() - this->last_traffic_ > (keepalive * 5) / 2) {
if (now - this->last_traffic_ > (keepalive * 5) / 2) {
ESP_LOGW(TAG, "'%s' didn't respond to ping request in time. Disconnecting...", this->client_info_.c_str());
this->disconnect_client();
}
} else if (millis() - this->last_traffic_ > keepalive) {
} else if (now - this->last_traffic_ > keepalive) {
this->sent_ping_ = true;
this->send_ping_request(PingRequest());
}
#ifdef USE_ESP32_CAMERA
// FIXME
if (this->image_reader_.available()) {
uint32_t space = this->client_->space();
// reserve 15 bytes for metadata, and at least 64 bytes of data
@ -152,6 +214,30 @@ void APIConnection::loop() {
}
#endif
}
void APIConnection::try_send_pending_data_() {
if (this->pending_send_buffer_.empty() || this->remove_)
return;
const uint8_t *data = &this->pending_send_buffer_[0];
const size_t len = this->pending_send_buffer_.size();
ssize_t written = this->socket_->write(data, len);
if (written == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// write would block
return;
}
this->on_fatal_error();
ESP_LOGW(TAG, "Error writing to socket: errno %d", errno);
return;
} else if (written == len) {
this->pending_send_buffer_.clear();
} else {
// FIXME: inefficient
this->pending_send_buffer_.erase(
this->pending_send_buffer_.begin(),
this->pending_send_buffer_.begin() + written
);
}
}
std::string get_default_unique_id(const std::string &component_type, Nameable *nameable) {
return App.get_name() + component_type + nameable->get_object_id();
@ -703,8 +789,7 @@ bool APIConnection::send_log_message(int level, const char *tag, const char *lin
}
HelloResponse APIConnection::hello(const HelloRequest &msg) {
this->client_info_ = msg.client_info + " (" + this->client_->remoteIP().toString().c_str();
this->client_info_ += ")";
this->client_info_ = msg.client_info + " (" + this->socket_->getpeername() + ")";
ESP_LOGV(TAG, "Hello from client: '%s'", this->client_info_.c_str());
HelloResponse resp;
@ -779,9 +864,47 @@ void APIConnection::subscribe_home_assistant_states(const SubscribeHomeAssistant
}
}
}
bool APIConnection::send_(const void *buf, size_t len, bool force) {
if (this->remove_)
return false;
if (len == 0)
return true;
// ESP_LOGD(TAG, "writing %s", hexencode((const uint8_t *) buf, len).c_str());
ssize_t written = this->socket_->write(buf, len);
bool add_to_pending = false;
if (written == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
// write would block
add_to_pending = force;
} else if (written == -1) {
this->on_fatal_error();
ESP_LOGW(TAG, "Error writing to socket: errno %d", errno);
return false;
} else if (written == len) {
// all written
return true;
} else if (written == 0) {
add_to_pending = force;
} else {
// partially written, must insert in pending data
add_to_pending = true;
}
if (add_to_pending) {
this->pending_send_buffer_.insert(
this->pending_send_buffer_.end(),
reinterpret_cast<const uint8_t *>(buf) + written,
reinterpret_cast<const uint8_t *>(buf) + len
);
}
return add_to_pending;
}
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) {
if (this->remove_)
return false;
this->try_send_pending_data_();
if (!this->pending_send_buffer_.empty())
// FIXME: still send important message like HA service
return false;
std::vector<uint8_t> header;
header.push_back(0x00);
@ -790,36 +913,25 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type)
size_t needed_space = buffer.get_buffer()->size() + header.size();
if (needed_space > this->client_->space()) {
delay(0);
if (needed_space > this->client_->space()) {
// SubscribeLogsResponse
if (message_type != 29) {
ESP_LOGV(TAG, "Cannot send message because of TCP buffer space");
}
delay(0);
return false;
}
if (!this->send_(header.data(), header.size(), false)) {
// nothing written and doesn't fit
return false;
}
this->client_->add(reinterpret_cast<char *>(header.data()), header.size(),
ASYNC_WRITE_FLAG_COPY | ASYNC_WRITE_FLAG_MORE);
this->client_->add(reinterpret_cast<char *>(buffer.get_buffer()->data()), buffer.get_buffer()->size(),
ASYNC_WRITE_FLAG_COPY);
bool ret = this->client_->send();
return ret;
// force send because we already sent the header
this->send_(buffer.get_buffer()->data(), buffer.get_buffer()->size(), true);
return true;
}
void APIConnection::on_unauthenticated_access() {
ESP_LOGD(TAG, "'%s' tried to access without authentication.", this->client_info_.c_str());
this->on_fatal_error();
ESP_LOGD(TAG, "'%s' tried to access without authentication.", this->client_info_.c_str());
}
void APIConnection::on_no_setup_connection() {
ESP_LOGD(TAG, "'%s' tried to access without full connection.", this->client_info_.c_str());
this->on_fatal_error();
ESP_LOGD(TAG, "'%s' tried to access without full connection.", this->client_info_.c_str());
}
void APIConnection::on_fatal_error() {
ESP_LOGV(TAG, "Error: Disconnecting %s", this->client_info_.c_str());
this->client_->close();
this->socket_->close();
this->remove_ = true;
}

View file

@ -11,9 +11,10 @@ namespace api {
class APIConnection : public APIServerConnection {
public:
APIConnection(AsyncClient *client, APIServer *parent);
APIConnection(std::unique_ptr<socket::Socket> socket, APIServer *parent);
virtual ~APIConnection();
void start();
void disconnect_client();
void loop();
@ -135,19 +136,18 @@ class APIConnection : public APIServerConnection {
void on_unauthenticated_access() override;
void on_no_setup_connection() override;
ProtoWriteBuffer create_buffer() override {
this->send_buffer_.clear();
return {&this->send_buffer_};
// FIXME: ensure no recursive writes can happen
this->proto_write_buffer_.clear();
return {&this->proto_write_buffer_};
}
bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) override;
protected:
friend APIServer;
void on_error_(int8_t error);
void on_disconnect_();
void on_timeout_(uint32_t time);
void on_data_(uint8_t *buf, size_t len);
void parse_recv_buffer_();
bool send_(const void *buf, size_t len, bool force);
void try_send_pending_data_();
enum class ConnectionState {
WAITING_FOR_HELLO,
@ -157,8 +157,14 @@ class APIConnection : public APIServerConnection {
bool remove_{false};
std::vector<uint8_t> send_buffer_;
// Buffer used to encode proto messages
// Re-use to prevent allocations
std::vector<uint8_t> proto_write_buffer_;
// Buffer containing pending sends
std::vector<uint8_t> pending_send_buffer_;
// Buffer containing data that was received but not parsed yet
std::vector<uint8_t> recv_buffer_;
size_t recv_len_{0};
std::string client_info_;
#ifdef USE_ESP32_CAMERA
@ -170,9 +176,8 @@ class APIConnection : public APIServerConnection {
uint32_t last_traffic_;
bool sent_ping_{false};
bool service_call_subscription_{false};
bool current_nodelay_{false};
bool next_close_{false};
AsyncClient *client_;
std::unique_ptr<socket::Socket> socket_;
APIServer *parent_;
InitialStateIterator initial_state_iterator_;
ListEntitiesIterator list_entities_iterator_;

View file

@ -5,6 +5,8 @@
#include "esphome/core/util.h"
#include "esphome/core/defines.h"
#include "esphome/core/version.h"
#include <errno.h>
//#include <arpa/inet.h>
#ifdef USE_LOGGER
#include "esphome/components/logger/logger.h"
@ -21,20 +23,115 @@ static const char *const TAG = "api";
void APIServer::setup() {
ESP_LOGCONFIG(TAG, "Setting up Home Assistant API server...");
this->setup_controller();
this->server_ = AsyncServer(this->port_);
this->server_.setNoDelay(false);
this->server_.begin();
this->server_.onClient(
[](void *s, AsyncClient *client) {
if (client == nullptr)
return;
socket_ = socket::socket(AF_INET, SOCK_STREAM, 0);
if (socket_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket.");
this->mark_failed();
return;
}
int enable = 1;
int err = socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
// we can still continue
}
err = socket_->setblocking(false);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
this->mark_failed();
return;
}
/*struct sockaddr_storage dest_addr;
memset(&dest_addr, 0, sizeof(dest_addr));
struct sockaddr_in *dest_addr_ip4 = (struct sockaddr_in *) &dest_addr;
dest_addr_ip4->sin_addr.s_addr = htonl(INADDR_ANY);
dest_addr_ip4->sin_family = AF_INET;
dest_addr_ip4->sin_port = htons(this->port_);
err = socket_->bind((struct sockaddr *) &dest_addr, sizeof(dest_addr));*/
struct sockaddr_in server;
memset(&server, 0, sizeof(server));
server.sin_family = AF_INET;
server.sin_addr.s_addr = INADDR_ANY;
server.sin_port = htons(this->port_);
err = socket_->bind((struct sockaddr *) &server, sizeof(server));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
this->mark_failed();
return;
}
err = socket_->listen(4);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
this->mark_failed();
return;
}
ssl_ = ssl::create_context();
if (!ssl_) {
ESP_LOGW(TAG, "Failed to create SSL context: errno %d", errno);
this->mark_failed();
return;
}
ssl_->set_server_certificate(R"(-----BEGIN CERTIFICATE-----
MIIDETCCAfkCFGNtbm6nA3CZM7no7HqdWikhUMSkMA0GCSqGSIb3DQEBCwUAMEUx
CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl
cm5ldCBXaWRnaXRzIFB0eSBMdGQwHhcNMjEwODA5MTgwOTMyWhcNMjEwOTA4MTgw
OTMyWjBFMQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UE
CgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOC
AQ8AMIIBCgKCAQEAwbt/qjWftqZtdRaJ5QjRf/8Sh6JT8KN4Bu9cGbHJIKAQLhy6
8/qdB24Ar8SyuKEaV8HRcCguTQ58jdK5rbaQu/Zpppgy9lF3AHH1MhVHavGNca3A
ejFtJr4DuTLkv/HjpgcAHjhZk+mFeNXrHeFrPIzF3imSyV1xyqoBxpa1cCFH/D3J
o2S6PMdAEcHSoaP5TEuM9e2j9Sc97LughMaFkR1R4cz2kEyMZIOASHkFCJMV6pjg
PVOqxu11oFYJn9/zh1Ea6PChYq+bGBmj60vwh+tpA6E8T0PzkxUuVklAD5pBfXoD
y8xW8ulc0CPSGSaxbn2vudUBJyZFvQBTQhFVcwIDAQABMA0GCSqGSIb3DQEBCwUA
A4IBAQBQM80osk+ryQ+CqBhyOLQBOeQkmCVNzMUjVBG7tP4vkuJtfAdyUuKBWGtr
X2VrkL0yueeDt9rdib5QbXWih4sT7KdQlnSBmnrac0MM7wCh+lhCnJhWWCUBHP9s
8rkL2XOrISbVi80wqpJn0y4FMaoK6KnxyelallHuNZ+3EZAuZrhGAkV68Z83CIFO
5emvAIGq73U/lddLDV6sz7zWeDdnyfTpkLzml8wJLO9Ob7o6aw7WJK/edjYdc2XW
pIMatEESaN9MlWI5SXQS4AcMnqdUqab5587cHDrgVjBd8RmvdyT9j2v7nS2JyEK0
DkASogRqBmCLR1/0UW+dFARCZI9k
-----END CERTIFICATE-----
)");
ssl_->set_private_key(R"(-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAwbt/qjWftqZtdRaJ5QjRf/8Sh6JT8KN4Bu9cGbHJIKAQLhy6
8/qdB24Ar8SyuKEaV8HRcCguTQ58jdK5rbaQu/Zpppgy9lF3AHH1MhVHavGNca3A
ejFtJr4DuTLkv/HjpgcAHjhZk+mFeNXrHeFrPIzF3imSyV1xyqoBxpa1cCFH/D3J
o2S6PMdAEcHSoaP5TEuM9e2j9Sc97LughMaFkR1R4cz2kEyMZIOASHkFCJMV6pjg
PVOqxu11oFYJn9/zh1Ea6PChYq+bGBmj60vwh+tpA6E8T0PzkxUuVklAD5pBfXoD
y8xW8ulc0CPSGSaxbn2vudUBJyZFvQBTQhFVcwIDAQABAoIBAGVpQuDUhTBVWkLK
c5CC1zfLS+XYIVx8FZ57uZhxqjj70LxyqaKBc6Wp/Y4ExxFCs8lwWbP+NI59oNGU
l0HJqWXbDV75mOO7rTF8db+rx+DBZSs2quTL7rkzCjvt2jRn6KTGUVeAY9O7j/S6
9gKEN2BQyFsNJBtoYOKXr6pGxd9Vg3K0j7DJXf5uK7lWIrxtU9k7QgMJFdnhbzEu
0TnhFEVMDdBIm+yrTdL8lmIdT8DOUIJgyTJ1iICFndksPQSgBAWQaGKaaxZbn0c3
Oy778VFqT2HywHVbJQL4XBe/yYUhjbpF1Hv9EEbK3Rm04xsCDbZru6/AK88gHBk4
b7uUSwECgYEA4BmML1isP9h8zqAvCEFjFmAWpoLBBZ+5I9Go1PIWhlCYY+G7AUXw
zxb0J6d9UGsYTkJXlgE77+HBzqlgyhCkngNuAAPm37ebdwuy5iBr32c9RLahR5W5
Nh+J3le9JTXe9B9uwfggD06dBFmhgG0PQdyBr4Daa3a8VRJAD1MGYMECgYEA3U9U
QwxQOYBkdJTbIQnTP7vnFuhWn9V5BMn5PczJSwGJEgaHgIL5Bm5NHa/ON3UX6QIi
uk73fGfohN8Ii1MjVKNFKM/LZ30XSufVHrm7yH6xRR4qbZUk4KhKxV/uOVluv38P
bis9B9cye3ETnjDhkWK4/TJeTHHlTAKMQuOQzzMCgYEAmtlsYYbvNwq7aveKqDSu
aFarMBGnmOA+SP7ln4dMgzELq/DdjEqs1BwzR3dXgwsNd34mEVP2+5HOnqOxas7H
QRxzlPUdQjcX6NGfo56Bi5RF5MYheVp+6WQvmwCbhSvNTHivyr5OQOV8X/YjP5+c
bFEXF5N82cbo6gu7Uht3i8ECgYAh511JSEGiDYFWOte3IAI06VxlrgJXSiTYDvkX
9p9/1iRhlo57qZTs30kBG0XESTP4hlM7p41SibidYm20qm/nL3wQ3ISUvh0rZIjJ
xDp4ZLBTnmNxlj+oCyApTKD6ODE3NQfwIL+gy973+kK/IU3tL+qXH3hCzdAK7Pj/
5kzw8QKBgQDUYQGH1JT93Yn9uIyfX1v6HcB1azDbF16JEOFZoGlS1gxFCobIb7jA
2/Y0HfFUUfDGexjQNReFi0IXjgBvYmJX7rF9tGsTdXh35Lu2cTd0DcykGPVcFyJW
PSf0vGzbAqpdriYQStaed+HgTdW6kHsOBNeJbbJkjsQpoaoWX3tEDw==
-----END RSA PRIVATE KEY-----
)");
err = ssl_->init();
if (err != 0) {
ESP_LOGW(TAG, "Failed to initialize SSL context: errno %d", errno);
this->mark_failed();
return;
}
// can't print here because in lwIP thread
// ESP_LOGD(TAG, "New client connected from %s", client->remoteIP().toString().c_str());
auto *a_this = (APIServer *) s;
a_this->clients_.push_back(new APIConnection(client, a_this));
},
this);
#ifdef USE_LOGGER
if (logger::global_logger != nullptr) {
logger::global_logger->add_on_log_callback([this](int level, const char *tag, const char *message) {
@ -59,6 +156,27 @@ void APIServer::setup() {
#endif
}
void APIServer::loop() {
// Accept new clients
while (true) {
struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr);
auto sock = socket_->accept((struct sockaddr *) &source_addr, &addr_len);
if (!sock)
break;
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());
// wrap socket
auto sock2 = ssl_->wrap_socket(std::move(sock));
if (!sock2) {
ESP_LOGW(TAG, "Failed to wrap socket with SSL: errno %d", errno);
continue;
}
auto *conn = new APIConnection(std::move(sock2), this);
clients_.push_back(conn);
conn->start();
}
// Partition clients into remove and active
auto new_end =
std::partition(this->clients_.begin(), this->clients_.end(), [](APIConnection *conn) { return !conn->remove_; });

View file

@ -4,6 +4,8 @@
#include "esphome/core/controller.h"
#include "esphome/core/defines.h"
#include "esphome/core/log.h"
#include "esphome/components/socket/socket.h"
#include "esphome/components/ssl/ssl_context.h"
#include "api_pb2.h"
#include "api_pb2_service.h"
#include "util.h"
@ -11,13 +13,6 @@
#include "subscribe_state.h"
#include "user_services.h"
#ifdef ARDUINO_ARCH_ESP32
#include <AsyncTCP.h>
#endif
#ifdef ARDUINO_ARCH_ESP8266
#include <ESPAsyncTCP.h>
#endif
namespace esphome {
namespace api {
@ -86,7 +81,8 @@ class APIServer : public Component, public Controller {
const std::vector<UserServiceDescriptor *> &get_user_services() const { return this->user_services_; }
protected:
AsyncServer server_{0};
std::unique_ptr<socket::Socket> socket_ = nullptr;
std::unique_ptr<ssl::SSLContext> ssl_ = nullptr;
uint16_t port_{6053};
uint32_t reboot_timeout_{300000};
uint32_t last_connected_{0};

View file

@ -15,6 +15,7 @@ from esphome.core import CORE, coroutine_with_priority
CODEOWNERS = ["@esphome/core"]
DEPENDENCIES = ["network"]
AUTO_LOAD = ["socket"]
CONF_ON_STATE_CHANGE = "on_state_change"
CONF_ON_BEGIN = "on_begin"

View file

@ -3,7 +3,7 @@
#include "esphome/core/log.h"
#include "esphome/core/application.h"
#include "esphome/core/util.h"
#include <errno.h>
#include <cstdio>
#include <MD5Builder.h>
#ifdef ARDUINO_ARCH_ESP32
@ -19,8 +19,44 @@ static const char *const TAG = "ota";
static const uint8_t OTA_VERSION_1_0 = 1;
void OTAComponent::setup() {
this->server_ = new WiFiServer(this->port_);
this->server_->begin();
server_ = socket::socket(AF_INET, SOCK_STREAM, 0);
if (server_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket.");
this->mark_failed();
return;
}
int enable = 1;
int err = server_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
// we can still continue
}
err = server_->setblocking(false);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
this->mark_failed();
return;
}
struct sockaddr_in server;
memset(&server, 0, sizeof(server));
server.sin_family = AF_INET;
server.sin_addr.s_addr = INADDR_ANY;
server.sin_port = htons(this->port_);
err = server_->bind((struct sockaddr *) &server, sizeof(server));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
this->mark_failed();
return;
}
err = server_->listen(4);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
this->mark_failed();
return;
}
this->dump_config();
}
@ -59,23 +95,28 @@ void OTAComponent::handle_() {
uint8_t ota_features;
(void) ota_features;
if (!this->client_.connected()) {
this->client_ = this->server_->available();
if (client_ == nullptr) {
struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr);
client_ = server_->accept((struct sockaddr *) &source_addr, &addr_len);
}
if (client_ == nullptr)
return;
if (!this->client_.connected())
return;
int enable = 1;
int err = client_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
if (err != 0) {
ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno);
return;
}
// enable nodelay for outgoing data
this->client_.setNoDelay(true);
ESP_LOGD(TAG, "Starting OTA Update from %s...", this->client_.remoteIP().toString().c_str());
ESP_LOGD(TAG, "Starting OTA Update from %s...", this->client_->getpeername().c_str());
this->status_set_warning();
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(OTA_STARTED, 0.0f, 0);
#endif
if (!this->wait_receive_(buf, 5)) {
if (!this->readall_(buf, 5)) {
ESP_LOGW(TAG, "Reading magic bytes failed!");
goto error;
}
@ -88,11 +129,12 @@ void OTAComponent::handle_() {
}
// Send OK and version - 2 bytes
this->client_.write(OTA_RESPONSE_OK);
this->client_.write(OTA_VERSION_1_0);
buf[0] = OTA_RESPONSE_OK;
buf[1] = OTA_VERSION_1_0;
this->writeall_(buf, 2);
// Read features - 1 byte
if (!this->wait_receive_(buf, 1)) {
if (!this->readall_(buf, 1)) {
ESP_LOGW(TAG, "Reading features failed!");
goto error;
}
@ -100,10 +142,12 @@ void OTAComponent::handle_() {
ESP_LOGV(TAG, "OTA features is 0x%02X", ota_features);
// Acknowledge header - 1 byte
this->client_.write(OTA_RESPONSE_HEADER_OK);
buf[0] = OTA_RESPONSE_HEADER_OK;
this->writeall_(buf, 1);
if (!this->password_.empty()) {
this->client_.write(OTA_RESPONSE_REQUEST_AUTH);
buf[0] = OTA_RESPONSE_REQUEST_AUTH;
this->writeall_(buf, 1);
MD5Builder md5_builder{};
md5_builder.begin();
sprintf(sbuf, "%08X", random_uint32());
@ -113,7 +157,7 @@ void OTAComponent::handle_() {
ESP_LOGV(TAG, "Auth: Nonce is %s", sbuf);
// Send nonce, 32 bytes hex MD5
if (this->client_.write(reinterpret_cast<uint8_t *>(sbuf), 32) != 32) {
if (!this->writeall_(reinterpret_cast<uint8_t *>(sbuf), 32)) {
ESP_LOGW(TAG, "Auth: Writing nonce failed!");
goto error;
}
@ -125,7 +169,7 @@ void OTAComponent::handle_() {
md5_builder.add(sbuf);
// Receive cnonce, 32 bytes hex MD5
if (!this->wait_receive_(buf, 32)) {
if (!this->readall_(buf, 32)) {
ESP_LOGW(TAG, "Auth: Reading cnonce failed!");
goto error;
}
@ -140,7 +184,7 @@ void OTAComponent::handle_() {
ESP_LOGV(TAG, "Auth: Result is %s", sbuf);
// Receive result, 32 bytes hex MD5
if (!this->wait_receive_(buf + 64, 32)) {
if (!this->writeall_(buf + 64, 32)) {
ESP_LOGW(TAG, "Auth: Reading response failed!");
goto error;
}
@ -159,10 +203,11 @@ void OTAComponent::handle_() {
}
// Acknowledge auth OK - 1 byte
this->client_.write(OTA_RESPONSE_AUTH_OK);
buf[0] = OTA_RESPONSE_AUTH_OK;
this->writeall_(buf, 1);
// Read size, 4 bytes MSB first
if (!this->wait_receive_(buf, 4)) {
if (!this->readall_(buf, 4)) {
ESP_LOGW(TAG, "Reading size failed!");
goto error;
}
@ -211,10 +256,11 @@ void OTAComponent::handle_() {
update_started = true;
// Acknowledge prepare OK - 1 byte
this->client_.write(OTA_RESPONSE_UPDATE_PREPARE_OK);
buf[0] = OTA_RESPONSE_UPDATE_PREPARE_OK;
this->writeall_(buf, 1);
// Read binary MD5, 32 bytes
if (!this->wait_receive_(buf, 32)) {
if (!this->readall_(buf, 32)) {
ESP_LOGW(TAG, "Reading binary MD5 checksum failed!");
goto error;
}
@ -223,17 +269,22 @@ void OTAComponent::handle_() {
Update.setMD5(sbuf);
// Acknowledge MD5 OK - 1 byte
this->client_.write(OTA_RESPONSE_BIN_MD5_OK);
buf[0] = OTA_RESPONSE_BIN_MD5_OK;
this->writeall_(buf, 1);
while (!Update.isFinished()) {
size_t available = this->wait_receive_(buf, 0);
if (!available) {
// TODO: timeout check
ssize_t read = this->client_->read(buf, sizeof(buf));
if (read == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK)
continue;
ESP_LOGW(TAG, "Error receiving data for update, errno: %d", errno);
goto error;
}
uint32_t written = Update.write(buf, available);
if (written != available) {
ESP_LOGW(TAG, "Error writing binary data to flash: %u != %u!", written, available); // NOLINT
uint32_t written = Update.write(buf, read);
if (written != read) {
ESP_LOGW(TAG, "Error writing binary data to flash: %u != %u!", written, read); // NOLINT
error_code = OTA_RESPONSE_ERROR_WRITING_FLASH;
goto error;
}
@ -253,7 +304,8 @@ void OTAComponent::handle_() {
}
// Acknowledge receive OK - 1 byte
this->client_.write(OTA_RESPONSE_RECEIVE_OK);
buf[0] = OTA_RESPONSE_RECEIVE_OK;
this->writeall_(buf, 1);
if (!Update.end()) {
error_code = OTA_RESPONSE_ERROR_UPDATE_END;
@ -261,16 +313,17 @@ void OTAComponent::handle_() {
}
// Acknowledge Update end OK - 1 byte
this->client_.write(OTA_RESPONSE_UPDATE_END_OK);
buf[0] = OTA_RESPONSE_UPDATE_END_OK;
this->writeall_(buf, 1);
// Read ACK
if (!this->wait_receive_(buf, 1, false) || buf[0] != OTA_RESPONSE_OK) {
if (!this->readall_(buf, 1) || buf[0] != OTA_RESPONSE_OK) {
ESP_LOGW(TAG, "Reading back acknowledgement failed!");
// do not go to error, this is not fatal
}
this->client_.flush();
this->client_.stop();
this->client_->close();
this->client_ = nullptr;
delay(10);
ESP_LOGI(TAG, "OTA update finished!");
this->status_clear_warning();
@ -286,11 +339,10 @@ error:
Update.printError(ss);
ESP_LOGW(TAG, "Update end failed! Error: %s", ss.c_str());
}
if (this->client_.connected()) {
this->client_.write(static_cast<uint8_t>(error_code));
this->client_.flush();
}
this->client_.stop();
buf[0] = static_cast<uint8_t>(error_code);
this->writeall_(buf, 1);
this->client_->close();
this->client_ = nullptr;
#ifdef ARDUINO_ARCH_ESP32
if (update_started) {
@ -314,52 +366,56 @@ error:
#endif
}
size_t OTAComponent::wait_receive_(uint8_t *buf, size_t bytes, bool check_disconnected) {
size_t available = 0;
bool OTAComponent::readall_(uint8_t *buf, size_t len) {
uint32_t start = millis();
do {
App.feed_wdt();
if (check_disconnected && !this->client_.connected()) {
ESP_LOGW(TAG, "Error client disconnected while receiving data!");
return 0;
}
int availi = this->client_.available();
if (availi < 0) {
ESP_LOGW(TAG, "Error reading data!");
return 0;
}
uint32_t at = 0;
while (len - at > 0) {
uint32_t now = millis();
if (availi == 0 && now - start > 10000) {
ESP_LOGW(TAG, "Timeout waiting for data!");
return 0;
if (now - start > 1000) {
ESP_LOGW(TAG, "Timed out reading %d bytes of data", len);
return false;
}
available = size_t(availi);
yield();
} while (bytes == 0 ? available == 0 : available < bytes);
if (bytes == 0)
bytes = std::min(available, size_t(1024));
bool success = false;
for (uint32_t i = 0; !success && i < 100; i++) {
int res = this->client_.read(buf, bytes);
if (res != int(bytes)) {
// ESP32 implementation has an issue where calling read can fail with EAGAIN (race condition)
// so just re-try it until it works (with generous timeout of 1s)
// because we check with available() first this should not cause us any trouble in all other cases
delay(10);
ssize_t read = this->client_->read(buf + at, len - at);
if (read == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
delay(1);
continue;
}
ESP_LOGW(TAG, "Failed to read %d bytes of data, errno: %d", len, errno);
return false;
} else {
success = true;
at += read;
}
delay(1);
}
if (!success) {
ESP_LOGW(TAG, "Reading %u bytes of binary data failed!", bytes); // NOLINT
return 0;
}
return true;
}
bool OTAComponent::writeall_(const uint8_t *buf, size_t len) {
uint32_t start = millis();
uint32_t at = 0;
while (len - at > 0) {
uint32_t now = millis();
if (now - start > 1000) {
ESP_LOGW(TAG, "Timed out writing %d bytes of data", len);
return false;
}
return bytes;
ssize_t written = this->client_->write(buf + at, len - at);
if (written == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
delay(1);
continue;
}
ESP_LOGW(TAG, "Failed to write %d bytes of data, errno: %d", len, errno);
return false;
} else {
at += written;
}
delay(1);
}
return true;
}
void OTAComponent::set_auth_password(const std::string &password) { this->password_ = password; }

View file

@ -3,8 +3,7 @@
#include "esphome/core/component.h"
#include "esphome/core/preferences.h"
#include "esphome/core/helpers.h"
#include <WiFiServer.h>
#include <WiFiClient.h>
#include "esphome/components/socket/socket.h"
namespace esphome {
namespace ota {
@ -74,14 +73,15 @@ class OTAComponent : public Component {
uint32_t read_rtc_();
void handle_();
size_t wait_receive_(uint8_t *buf, size_t bytes, bool check_disconnected = true);
bool readall_(uint8_t *buf, size_t len);
bool writeall_(const uint8_t *buf, size_t len);
std::string password_;
uint16_t port_;
WiFiServer *server_{nullptr};
WiFiClient client_{};
std::unique_ptr<socket::Socket> server_;
std::unique_ptr<socket::Socket> client_;
bool has_safe_mode_{false}; ///< stores whether safe mode can be enabled.
uint32_t safe_mode_start_time_; ///< stores when safe mode was enabled.

View file

@ -0,0 +1,2 @@
# Dummy package to allow components to depend on network
CODEOWNERS = ["@esphome/core"]

View file

@ -0,0 +1,45 @@
#pragma once
#include <string>
#include <sys/types.h>
#include <sys/socket.h>
#include <memory>
#include "esphome/core/optional.h"
namespace esphome {
namespace socket {
using socklen_t = uint32_t;
class Socket {
public:
Socket() = default;
virtual ~Socket() = default;
Socket(const Socket&) = delete;
Socket &operator=(const Socket &) = delete;
virtual std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) = 0;
virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0;
virtual int close() = 0;
virtual int connect(const std::string &address) = 0;
virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0;
virtual int shutdown(int how) = 0;
virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0;
virtual std::string getpeername() = 0;
virtual int getsockname(struct sockaddr *addr, socklen_t *addrlen) = 0;
virtual std::string getsockname() = 0;
virtual int getsockopt(int level, int optname, void *optval, socklen_t *optlen) = 0;
virtual int setsockopt(int level, int optname, const void *optval, socklen_t optlen) = 0;
virtual int listen(int backlog) = 0;
virtual ssize_t read(void *buf, size_t len) = 0;
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
virtual ssize_t write(const void *buf, size_t len) = 0;
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
virtual int setblocking(bool blocking) = 0;
};
std::unique_ptr<Socket> socket(int domain, int type, int protocol);
} // socket
} // esphome

View file

@ -0,0 +1,118 @@
#include "socket.h"
#include <unistd.h>
#include <fcntl.h>
#include <string.h>
namespace esphome {
namespace socket {
std::string format_sockaddr(const struct sockaddr_storage &storage) {
if (storage.ss_family == AF_INET) {
const struct sockaddr_in *addr = reinterpret_cast<const struct sockaddr_in *>(&storage);
char buf[INET_ADDRSTRLEN];
const char *ret = inet_ntop(AF_INET, &addr->sin_addr, buf, sizeof(buf));
if (ret == NULL)
return {};
return std::string{buf};
} else if (storage.ss_family == AF_INET6) {
const struct sockaddr_in6 *addr = reinterpret_cast<const struct sockaddr_in6 *>(&storage);
char buf[INET6_ADDRSTRLEN];
const char *ret = inet_ntop(AF_INET6, &addr->sin6_addr, buf, sizeof(buf));
if (ret == NULL)
return {};
return std::string{buf};
}
return {};
}
class SocketImplSocket : public Socket {
public:
SocketImplSocket(int fd) : Socket(), fd_(fd) {}
~SocketImplSocket() override {
if (!closed_) {
close();
}
}
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
int fd = ::accept(fd_, addr, addrlen);
if (fd == -1)
return {};
return std::unique_ptr<SocketImplSocket>{new SocketImplSocket(fd)};
}
int bind(const struct sockaddr *addr, socklen_t addrlen) override {
return ::bind(fd_, addr, addrlen);
}
int close() override {
int ret = ::close(fd_);
closed_ = true;
return ret;
}
int connect(const std::string &address) override {
// TODO
return 0;
}
int connect(const struct sockaddr *addr, socklen_t addrlen) override {
return ::connect(fd_, addr, addrlen);
}
int shutdown(int how) override {
return ::shutdown(fd_, how);
}
int getpeername(struct sockaddr *addr, socklen_t *addrlen) override {
return ::getpeername(fd_, addr, addrlen);
}
std::string getpeername() override {
struct sockaddr_storage storage;
socklen_t len = sizeof(storage);
int err = this->getpeername((struct sockaddr *) &storage, &len);
if (err != 0)
return {};
return format_sockaddr(storage);
}
int getsockname(struct sockaddr *addr, socklen_t *addrlen) override {
return ::getsockname(fd_, addr, addrlen);
}
std::string getsockname() override {
struct sockaddr_storage storage;
socklen_t len = sizeof(storage);
int err = this->getsockname((struct sockaddr *) &storage, &len);
if (err != 0)
return {};
return format_sockaddr(storage);
}
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
return ::getsockopt(fd_, level, optname, optval, optlen);
}
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
return ::setsockopt(fd_, level, optname, optval, optlen);
}
int listen(int backlog) override {
return ::listen(fd_, backlog);
}
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
ssize_t read(void *buf, size_t len) override {
return ::read(fd_, buf, len);
}
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
ssize_t write(const void *buf, size_t len) override {
return ::write(fd_, buf, len);
}
int setblocking(bool blocking) override {
int fl = ::fcntl(fd_, F_GETFL, 0);
::fcntl(fd_, F_SETFL, fl | O_NONBLOCK);
return 0;
}
protected:
int fd_;
bool closed_ = false;
};
std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
int ret = ::socket(domain, type, protocol);
if (ret == -1)
return nullptr;
return std::unique_ptr<Socket>{new SocketImplSocket(ret)};
}
} // namespace socket
} // namespace esphome

View file

@ -0,0 +1 @@
AUTO_LOAD = ["socket"]

View file

@ -0,0 +1,233 @@
#include "ssl_context.h"
#include <string.h>
#include "mbedtls/platform.h"
#include "mbedtls/net_sockets.h"
#include "mbedtls/esp_debug.h"
#include "mbedtls/ssl.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/error.h"
#include "mbedtls/certs.h"
namespace esphome {
namespace ssl {
static int entropy_hw_random_source(void *data, uint8_t *output, size_t len, size_t *olen) {
esp_fill_random(output, len);
*olen = len;
return 0;
}
struct MbedTLSBioCtx {
socket::Socket *sock;
static int send(void *raw, const uint8_t *buf, size_t len) {
auto *ctx = reinterpret_cast<MbedTLSBioCtx *>(raw);
ssize_t ret = ctx->sock->write(buf, len);
if (ret != -1)
return ret;
if (errno == EWOULDBLOCK || errno == EAGAIN)
return MBEDTLS_ERR_SSL_WANT_WRITE;
if (errno == EPIPE || errno == ECONNRESET)
return MBEDTLS_ERR_NET_CONN_RESET;
if (errno == EINTR)
return MBEDTLS_ERR_SSL_WANT_WRITE;
return MBEDTLS_ERR_NET_SEND_FAILED;
}
static int recv(void *raw, uint8_t *buf, size_t len) {
auto *ctx = reinterpret_cast<MbedTLSBioCtx *>(raw);
ssize_t ret = ctx->sock->read(buf, len);
if (ret != -1)
return ret;
if (errno == EWOULDBLOCK || errno == EAGAIN)
return MBEDTLS_ERR_SSL_WANT_WRITE;
if (errno == EPIPE || errno == ECONNRESET)
return MBEDTLS_ERR_NET_CONN_RESET;
if (errno == EINTR)
return MBEDTLS_ERR_SSL_WANT_WRITE;
return MBEDTLS_ERR_NET_SEND_FAILED;
}
};
class MbedTLSWrappedSocket : public socket::Socket {
public:
MbedTLSWrappedSocket(std::unique_ptr<socket::Socket> sock)
: socket::Socket(), sock_(std::move(sock)) {}
~MbedTLSWrappedSocket() override {
mbedtls_ssl_free(&ssl_);
sock_ = nullptr;
}
void init(const mbedtls_ssl_config *conf) {
// TODO: reuse ssl contexts?
mbedtls_ssl_init(&ssl_);
mbedtls_ssl_setup(&ssl_, conf);
// sock pointer does not fit in void*
// instead store it in a heap-allocated var
auto *ctx = new MbedTLSBioCtx;
// unsafe, but should be fine because we free before sock is reset
ctx->sock = sock_.get();
mbedtls_ssl_set_bio(&ssl_, ctx, MbedTLSBioCtx::send, MbedTLSBioCtx::recv, nullptr);
do_handshake_ = true;
}
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
// only for server sockets
errno = EBADF;
return {};
}
int bind(const struct sockaddr *addr, socklen_t addrlen) override {
errno = EBADF;
return -1;
}
int close() override {
return sock_->close();
}
int connect(const std::string &address) override {
return sock_->connect(address);
}
int connect(const struct sockaddr *addr, socklen_t addrlen) override {
return sock_->connect(addr, addrlen);
}
int shutdown(int how) override {
int ret = mbedtls_ssl_close_notify(&ssl_);
if (ret != 0)
return this->mbedtls_to_errno_(ret);
return this->sock_->shutdown(how);
}
int getpeername(struct sockaddr *addr, socklen_t *addrlen) override {
return sock_->getpeername(addr, addrlen);
}
std::string getpeername() override {
return sock_->getpeername();
}
int getsockname(struct sockaddr *addr, socklen_t *addrlen) override {
return sock_->getsockname(addr, addrlen);
}
std::string getsockname() override {
return sock_->getsockname();
}
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
return sock_->getsockopt(level, optname, optval, optlen);
}
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
return sock_->setsockopt(level, optname, optval, optlen);
}
int listen(int backlog) override {
errno = EBADF;
return -1;
}
ssize_t read(void *buf, size_t len) override {
// mbedtls will automatically perform handshake here if necessary
int ret = mbedtls_ssl_read(&ssl_, reinterpret_cast<uint8_t *>(buf), len);
return this->mbedtls_to_errno_(ret);
}
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
ssize_t write(const void *buf, size_t len) override {
int ret = mbedtls_ssl_write(&ssl_, reinterpret_cast<const uint8_t *>(buf), len);
return this->mbedtls_to_errno_(ret);
}
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
int setblocking(bool blocking) override {
// TODO: handle blocking modes
return sock_->setblocking(blocking);
}
protected:
int mbedtls_to_errno_(int ret) {
if (ret > 0) {
return ret;
} else if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
errno = EWOULDBLOCK;
return -1;
} else if (ret == MBEDTLS_ERR_NET_CONN_RESET) {
errno = ECONNRESET;
return -1;
} else {
errno = EIO;
return -1;
}
}
std::unique_ptr<socket::Socket> sock_;
mbedtls_ssl_context ssl_;
bool do_handshake_ = false;
};
class MbedTLSContext : public SSLContext {
public:
MbedTLSContext() = default;
~MbedTLSContext() override {
mbedtls_pk_free(&pkey_);
mbedtls_entropy_free(&entropy_);
mbedtls_ctr_drbg_free(&ctr_drbg_);
mbedtls_x509_crt_free(&srv_cert_);
mbedtls_ssl_config_free(&conf_);
}
void set_server_certificate(const char *cert) override {
this->srv_cert_str_ = cert;
}
void set_private_key(const char *private_key) override {
this->privkey_str_ = private_key;
}
int init() override {
mbedtls_x509_crt_init(&srv_cert_);
mbedtls_ctr_drbg_init(&ctr_drbg_);
mbedtls_entropy_init(&entropy_);
mbedtls_pk_init(&pkey_);
mbedtls_ssl_config_init(&conf_);
// TODO check what this does
mbedtls_entropy_add_source(&entropy_, entropy_hw_random_source, NULL, 134, MBEDTLS_ENTROPY_SOURCE_STRONG);
mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, NULL, 0);
mbedtls_x509_crt_parse(
&srv_cert_,
reinterpret_cast<const uint8_t *>(srv_cert_str_),
strlen(srv_cert_str_)
);
mbedtls_pk_parse_key(
&pkey_,
reinterpret_cast<const uint8_t *>(privkey_str_),
strlen(privkey_str_),
nullptr,
0
);
mbedtls_ssl_config_defaults(
&conf_,
MBEDTLS_SSL_IS_SERVER,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT
);
mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_);
mbedtls_ssl_conf_own_cert(&conf_, &srv_cert_, &pkey_);
return 0;
}
std::unique_ptr<socket::Socket> wrap_socket(std::unique_ptr<socket::Socket> sock) override {
auto *wrapped = new MbedTLSWrappedSocket(std::move(sock));
wrapped->init(&conf_);
return std::unique_ptr<socket::Socket>{wrapped};
}
protected:
const char *srv_cert_str_ = nullptr;
const char *privkey_str_ = nullptr;
mbedtls_entropy_context entropy_;
mbedtls_ctr_drbg_context ctr_drbg_;
mbedtls_x509_crt srv_cert_;
mbedtls_pk_context pkey_;
mbedtls_ssl_config conf_;
};
std::unique_ptr<SSLContext> create_context() {
return std::unique_ptr<SSLContext>{new MbedTLSContext()};
}
} // namespace ssl
} // namespace esphome

View file

@ -0,0 +1,24 @@
#pragma once
#include <memory>
#include "esphome/components/socket/socket.h"
namespace esphome {
namespace ssl {
class SSLContext {
public:
SSLContext() = default;
virtual ~SSLContext() = default;
SSLContext(const SSLContext&) = delete;
SSLContext &operator=(const SSLContext &) = delete;
virtual int init() = 0;
virtual void set_server_certificate(const char *cert) = 0;
virtual void set_private_key(const char *private_key) = 0;
virtual std::unique_ptr<socket::Socket> wrap_socket(std::unique_ptr<socket::Socket> sock) = 0;
};
std::unique_ptr<SSLContext> create_context();
} // namespace ssl
} // namespace esphome