mirror of
https://github.com/esphome/esphome.git
synced 2024-11-09 16:57:47 +01:00
Complete
This commit is contained in:
parent
44041d2526
commit
88632b22e2
12 changed files with 470 additions and 71 deletions
|
@ -120,6 +120,9 @@ async def to_code(config):
|
|||
conf = config[CONF_ENCRYPTION]
|
||||
decoded = base64.b64decode(conf[CONF_KEY])
|
||||
cg.add(var.set_noise_psk(list(decoded)))
|
||||
cg.add_define("USE_API_NOISE")
|
||||
else:
|
||||
cg.add_define("USE_API_PLAINTEXT")
|
||||
|
||||
cg.add_define("USE_API")
|
||||
cg.add_global(api_ns.using)
|
||||
|
|
|
@ -25,7 +25,14 @@ APIConnection::APIConnection(std::unique_ptr<socket::Socket> sock, APIServer *pa
|
|||
list_entities_iterator_(parent, this) {
|
||||
this->proto_write_buffer_.reserve(64);
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
helper_ = std::unique_ptr<APIFrameHelper>{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())};
|
||||
#elif defined(USE_API_PLAINTEXT)
|
||||
helper_ = std::unique_ptr<APIFrameHelper>{new APIPlaintextFrameHelper(std::move(sock))};
|
||||
#else
|
||||
#error "No api frame helper enabled"
|
||||
#endif
|
||||
|
||||
}
|
||||
void APIConnection::start() {
|
||||
this->last_traffic_ = millis();
|
||||
|
|
|
@ -2,16 +2,12 @@
|
|||
|
||||
#include "esphome/core/log.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "proto.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
static const char *const TAG = "api.socket";
|
||||
static const char *const PROLOGUE_INIT = "NoiseAPIInit";
|
||||
|
||||
// TODO:
|
||||
// - track errors internally and return if in bad state
|
||||
// - send error on invalid psk
|
||||
|
||||
/// Is the given return value (from read/write syscalls) a wouldblock error?
|
||||
bool is_would_block(ssize_t ret) {
|
||||
|
@ -21,7 +17,10 @@ bool is_would_block(ssize_t ret) {
|
|||
return ret == 0;
|
||||
}
|
||||
|
||||
#define HELPER_LOG(msg, ...) ESP_LOGW(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__)
|
||||
#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__)
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
static const char *const PROLOGUE_INIT = "NoiseAPIInit";
|
||||
|
||||
/// Convert a noise error code to a readable error
|
||||
std::string noise_err_to_str(int err) {
|
||||
|
@ -93,7 +92,15 @@ APIError APINoiseFrameHelper::loop() {
|
|||
APIError err = state_action_();
|
||||
if (err == APIError::WOULD_BLOCK)
|
||||
return APIError::OK;
|
||||
return err;
|
||||
if (err != APIError::OK)
|
||||
return err;
|
||||
if (!tx_buf_.empty()) {
|
||||
err = try_send_tx_buf_();
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
|
||||
|
@ -180,7 +187,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) {
|
|||
}
|
||||
}
|
||||
|
||||
// ESP_LOGD(TAG, "Received frame: %s", hexencode(rx_buf_).c_str());
|
||||
// uncomment for even more debugging
|
||||
// ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str());
|
||||
frame->msg = std::move(rx_buf_);
|
||||
// consume msg
|
||||
rx_buf_ = {};
|
||||
|
@ -239,17 +247,33 @@ APIError APINoiseFrameHelper::state_action_() {
|
|||
// waiting for handshake msg
|
||||
ParsedFrame frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr == APIError::BAD_INDICATOR) {
|
||||
send_explicit_handshake_reject_("Bad indicator byte");
|
||||
return aerr;
|
||||
}
|
||||
if (frame.msg.size() < 1 || frame.msg[0] != 0x00) {
|
||||
aerr = APIError::BAD_HANDSHAKE_PACKET_LEN;
|
||||
}
|
||||
if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) {
|
||||
send_explicit_handshake_reject_("Bad handshake packet len");
|
||||
return aerr;
|
||||
}
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_input(mbuf, frame.msg.data(), frame.msg.size());
|
||||
noise_buffer_set_input(mbuf, frame.msg.data() + 1, frame.msg.size() - 1);
|
||||
err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
|
||||
if (err != 0) {
|
||||
// TODO: explicit rejection
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_handshakestate_read_message failed: %s", noise_err_to_str(err).c_str());
|
||||
if (err == NOISE_ERROR_MAC_FAILURE) {
|
||||
send_explicit_handshake_reject_("Handshake MAC failure");
|
||||
} else {
|
||||
send_explicit_handshake_reject_("Handshake error");
|
||||
}
|
||||
return APIError::HANDSHAKESTATE_READ_FAILED;
|
||||
}
|
||||
|
||||
|
@ -257,10 +281,10 @@ APIError APINoiseFrameHelper::state_action_() {
|
|||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
} else if (action == NOISE_ACTION_WRITE_MESSAGE) {
|
||||
uint8_t buffer[64];
|
||||
uint8_t buffer[65];
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_output(mbuf, buffer, sizeof(buffer));
|
||||
noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1);
|
||||
|
||||
err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr);
|
||||
if (err != 0) {
|
||||
|
@ -268,7 +292,9 @@ APIError APINoiseFrameHelper::state_action_() {
|
|||
HELPER_LOG("noise_handshakestate_write_message failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::HANDSHAKESTATE_WRITE_FAILED;
|
||||
}
|
||||
aerr = write_frame_(mbuf.data, mbuf.size);
|
||||
buffer[0] = 0x00; // success
|
||||
|
||||
aerr = write_frame_(buffer, mbuf.size + 1);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
aerr = check_handshake_finished_();
|
||||
|
@ -286,6 +312,15 @@ APIError APINoiseFrameHelper::state_action_() {
|
|||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) {
|
||||
std::vector<uint8_t> data;
|
||||
data.reserve(reason.size() + 1);
|
||||
data[0] = 0x01; // failure
|
||||
for (size_t i = 0; i < reason.size(); i++) {
|
||||
data[i+1] = (uint8_t) reason[i];
|
||||
}
|
||||
write_frame_(data.data(), data.size());
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
int err;
|
||||
|
@ -397,7 +432,7 @@ APIError APINoiseFrameHelper::write_packet(uint16_t type, const uint8_t *payload
|
|||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APINoiseFrameHelper::try_send_raw_() {
|
||||
APIError APINoiseFrameHelper::try_send_tx_buf_() {
|
||||
// try send from tx_buf
|
||||
while (state_ != State::CLOSED && !tx_buf_.empty()) {
|
||||
ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size());
|
||||
|
@ -428,11 +463,12 @@ APIError APINoiseFrameHelper::write_raw_(const uint8_t *data, size_t len) {
|
|||
int err;
|
||||
APIError aerr;
|
||||
|
||||
// ESP_LOGD(TAG, "Sending raw: %s", hexencode(data, len).c_str());
|
||||
// uncomment for even more debugging
|
||||
// ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str());
|
||||
|
||||
if (!tx_buf_.empty()) {
|
||||
// try to empty tx_buf_ first
|
||||
aerr = try_send_raw_();
|
||||
aerr = try_send_tx_buf_();
|
||||
if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK)
|
||||
return aerr;
|
||||
}
|
||||
|
@ -582,15 +618,290 @@ APIError APINoiseFrameHelper::shutdown(int how) {
|
|||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
|
||||
extern "C" {
|
||||
|
||||
// declare how noise generates random bytes (here with a good HWRNG based on the RF system)
|
||||
void noise_rand_bytes(void *output, size_t len) {
|
||||
esphome::fill_random(reinterpret_cast<uint8_t *>(output), len);
|
||||
}
|
||||
|
||||
}
|
||||
#endif // USE_API_NOISE
|
||||
|
||||
|
||||
#ifdef USE_API_PLAINTEXT
|
||||
|
||||
/// Initialize the frame helper, returns OK if successful.
|
||||
APIError APIPlaintextFrameHelper::init() {
|
||||
if (state_ != State::INITIALIZE || socket_ == nullptr) {
|
||||
HELPER_LOG("Bad state for init %d", (int) state_);
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
int err = socket_->setblocking(false);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Setting nonblocking failed with errno %d", errno);
|
||||
return APIError::TCP_NONBLOCKING_FAILED;
|
||||
}
|
||||
int enable = 1;
|
||||
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Setting nodelay failed with errno %d", errno);
|
||||
return APIError::TCP_NODELAY_FAILED;
|
||||
}
|
||||
|
||||
state_ = State::DATA;
|
||||
return APIError::OK;
|
||||
}
|
||||
/// Not used for plaintext
|
||||
APIError APIPlaintextFrameHelper::loop() {
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
// try send pending TX data
|
||||
if (!tx_buf_.empty()) {
|
||||
APIError err = try_send_tx_buf_();
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
|
||||
*
|
||||
* @param frame: The struct to hold the frame information in.
|
||||
* msg: store the parsed frame in that struct
|
||||
*
|
||||
* @return See APIError
|
||||
*
|
||||
* error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
|
||||
*/
|
||||
APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
|
||||
if (frame == nullptr) {
|
||||
HELPER_LOG("Bad argument for try_read_frame_");
|
||||
return APIError::BAD_ARG;
|
||||
}
|
||||
|
||||
// read header
|
||||
while (!rx_header_parsed_) {
|
||||
uint8_t data;
|
||||
ssize_t received = socket_->read(&data, 1);
|
||||
if (is_would_block(received)) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
} else if (received == -1) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket read failed with errno %d", errno);
|
||||
return APIError::SOCKET_READ_FAILED;
|
||||
}
|
||||
rx_header_buf_.push_back(data);
|
||||
|
||||
// try parse header
|
||||
if (rx_header_buf_[0] != 0x00) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]);
|
||||
return APIError::BAD_INDICATOR;
|
||||
}
|
||||
|
||||
size_t i = 1;
|
||||
size_t consumed = 0;
|
||||
auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[i], rx_header_buf_.size() - i, &consumed);
|
||||
if (!msg_size_varint.has_value()) {
|
||||
// not enough data there yet
|
||||
continue;
|
||||
}
|
||||
|
||||
i += consumed;
|
||||
rx_header_parsed_len_ = msg_size_varint->as_uint32();
|
||||
|
||||
auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[i], rx_header_buf_.size() - i, &consumed);
|
||||
if (!msg_type_varint.has_value()) {
|
||||
// not enough data there yet
|
||||
continue;
|
||||
}
|
||||
rx_header_parsed_type_ = msg_type_varint->as_uint32();
|
||||
rx_header_parsed_ = true;
|
||||
}
|
||||
// header reading done
|
||||
|
||||
// reserve space for body
|
||||
if (rx_buf_.size() != rx_header_parsed_len_) {
|
||||
rx_buf_.resize(rx_header_parsed_len_);
|
||||
}
|
||||
|
||||
if (rx_buf_len_ < rx_header_parsed_len_) {
|
||||
// more data to read
|
||||
size_t to_read = rx_header_parsed_len_ - rx_buf_len_;
|
||||
ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read);
|
||||
if (is_would_block(received)) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
} else if (received == -1) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket read failed with errno %d", errno);
|
||||
return APIError::SOCKET_READ_FAILED;
|
||||
}
|
||||
rx_buf_len_ += received;
|
||||
if (received != to_read) {
|
||||
// not all read
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
}
|
||||
|
||||
// uncomment for even more debugging
|
||||
// ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str());
|
||||
frame->msg = std::move(rx_buf_);
|
||||
// consume msg
|
||||
rx_buf_ = {};
|
||||
rx_buf_len_ = 0;
|
||||
rx_header_buf_.clear();
|
||||
rx_header_parsed_ = false;
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
ParsedFrame frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
buffer->container = std::move(frame.msg);
|
||||
buffer->data_offset = 0;
|
||||
buffer->data_len = rx_header_parsed_len_;
|
||||
buffer->type = rx_header_parsed_type_;
|
||||
return APIError::OK;
|
||||
}
|
||||
bool APIPlaintextFrameHelper::can_write_without_blocking() {
|
||||
return state_ == State::DATA && tx_buf_.empty();
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> header;
|
||||
header.push_back(0x00);
|
||||
ProtoVarInt(payload_len).encode(header);
|
||||
ProtoVarInt(type).encode(header);
|
||||
|
||||
aerr = write_raw_(&header[0], header.size());
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
aerr = write_raw_(payload, payload_len);
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::try_send_tx_buf_() {
|
||||
// try send from tx_buf
|
||||
while (state_ != State::CLOSED && !tx_buf_.empty()) {
|
||||
ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size());
|
||||
if (sent == -1) {
|
||||
if (errno == EWOULDBLOCK || errno == EAGAIN)
|
||||
break;
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket write failed with errno %d", errno);
|
||||
return APIError::SOCKET_WRITE_FAILED;
|
||||
} else if (sent == 0) {
|
||||
break;
|
||||
}
|
||||
// TODO: inefficient if multiple packets in txbuf
|
||||
// replace with deque of buffers
|
||||
tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent);
|
||||
}
|
||||
|
||||
return APIError::OK;
|
||||
}
|
||||
/** Write the data to the socket, or buffer it a write would block
|
||||
*
|
||||
* @param data The data to write
|
||||
* @param len The length of data
|
||||
*/
|
||||
APIError APIPlaintextFrameHelper::write_raw_(const uint8_t *data, size_t len) {
|
||||
if (len == 0)
|
||||
return APIError::OK;
|
||||
int err;
|
||||
APIError aerr;
|
||||
|
||||
// uncomment for even more debugging
|
||||
// ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str());
|
||||
|
||||
if (!tx_buf_.empty()) {
|
||||
// try to empty tx_buf_ first
|
||||
aerr = try_send_tx_buf_();
|
||||
if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK)
|
||||
return aerr;
|
||||
}
|
||||
|
||||
if (!tx_buf_.empty()) {
|
||||
// tx buf not empty, can't write now because then stream would be inconsistent
|
||||
tx_buf_.insert(tx_buf_.end(), data, data + len);
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
ssize_t sent = socket_->write(data, len);
|
||||
if (is_would_block(sent)) {
|
||||
// operation would block, add buffer to tx_buf
|
||||
tx_buf_.insert(tx_buf_.end(), data, data + len);
|
||||
return APIError::OK;
|
||||
} else if (sent == -1) {
|
||||
// an error occured
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket write failed with errno %d", errno);
|
||||
return APIError::SOCKET_WRITE_FAILED;
|
||||
} else if (sent != len) {
|
||||
// partially sent, add end to tx_buf
|
||||
tx_buf_.insert(tx_buf_.end(), data + sent, data + len);
|
||||
return APIError::OK;
|
||||
}
|
||||
// fully sent
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::write_frame_(const uint8_t *data, size_t len) {
|
||||
APIError aerr;
|
||||
|
||||
uint8_t header[3];
|
||||
header[0] = 0x01; // indicator
|
||||
header[1] = (uint8_t) (len >> 8);
|
||||
header[2] = (uint8_t) len;
|
||||
|
||||
aerr = write_raw_(header, 3);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
aerr = write_raw_(data, len);
|
||||
return aerr;
|
||||
}
|
||||
|
||||
APIError APIPlaintextFrameHelper::close() {
|
||||
state_ = State::CLOSED;
|
||||
int err = socket_->close();
|
||||
if (err == -1)
|
||||
return APIError::CLOSE_FAILED;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::shutdown(int how) {
|
||||
int err = socket_->shutdown(how);
|
||||
if (err == -1)
|
||||
return APIError::SHUTDOWN_FAILED;
|
||||
if (how == SHUT_RDWR) {
|
||||
state_ = State::CLOSED;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
#endif // USE_API_PLAINTEXT
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
|
|
|
@ -3,7 +3,11 @@
|
|||
#include <vector>
|
||||
#include <deque>
|
||||
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
#include "noise/protocol.h"
|
||||
#endif
|
||||
|
||||
#include "esphome/components/socket/socket.h"
|
||||
#include "api_noise_context.h"
|
||||
|
@ -63,39 +67,39 @@ class APIFrameHelper {
|
|||
virtual void set_log_info(std::string info) = 0;
|
||||
};
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
class APINoiseFrameHelper : public APIFrameHelper {
|
||||
public:
|
||||
APINoiseFrameHelper(std::unique_ptr<socket::Socket> socket, std::shared_ptr<APINoiseContext> ctx) : socket_(std::move(socket)), ctx_(ctx) {}
|
||||
~APINoiseFrameHelper();
|
||||
APIError init();
|
||||
APIError loop();
|
||||
APIError read_packet(ReadPacketBuffer *buffer);
|
||||
bool can_write_without_blocking();
|
||||
APIError write_packet(uint16_t type, const uint8_t *data, size_t len);
|
||||
std::string getpeername() {
|
||||
APIError init() override;
|
||||
APIError loop() override;
|
||||
APIError read_packet(ReadPacketBuffer *buffer) override;
|
||||
bool can_write_without_blocking() override;
|
||||
APIError write_packet(uint16_t type, const uint8_t *data, size_t len) override;
|
||||
std::string getpeername() override{
|
||||
return socket_->getpeername();
|
||||
}
|
||||
APIError close();
|
||||
APIError shutdown(int how);
|
||||
APIError close() override;
|
||||
APIError shutdown(int how) override;
|
||||
// Give this helper a name for logging
|
||||
void set_log_info(std::string info) {
|
||||
void set_log_info(std::string info) override {
|
||||
info_ = std::move(info);
|
||||
}
|
||||
|
||||
protected:
|
||||
APIError reserve_rx_buf_(size_t new_capacity);
|
||||
|
||||
struct ParsedFrame {
|
||||
std::vector<uint8_t> msg;
|
||||
};
|
||||
|
||||
APIError state_action_();
|
||||
APIError try_read_frame_(ParsedFrame *frame);
|
||||
APIError try_send_raw_();
|
||||
APIError try_send_tx_buf_();
|
||||
APIError write_frame_(const uint8_t *data, size_t len);
|
||||
APIError write_raw_(const uint8_t *data, size_t len);
|
||||
APIError init_handshake_();
|
||||
APIError check_handshake_finished_();
|
||||
void send_explicit_handshake_reject_(const std::string &reason);
|
||||
|
||||
std::unique_ptr<socket::Socket> socket_;
|
||||
|
||||
|
@ -124,6 +128,59 @@ class APINoiseFrameHelper : public APIFrameHelper {
|
|||
FAILED = 7,
|
||||
} state_ = State::INITIALIZE;
|
||||
};
|
||||
#endif // USE_API_NOISE
|
||||
|
||||
#ifdef USE_API_PLAINTEXT
|
||||
class APIPlaintextFrameHelper : public APIFrameHelper {
|
||||
public:
|
||||
APIPlaintextFrameHelper(std::unique_ptr<socket::Socket> socket) : socket_(std::move(socket)) {}
|
||||
~APIPlaintextFrameHelper() = default;
|
||||
APIError init() override;
|
||||
APIError loop() override;
|
||||
APIError read_packet(ReadPacketBuffer *buffer) override;
|
||||
bool can_write_without_blocking() override;
|
||||
APIError write_packet(uint16_t type, const uint8_t *data, size_t len) override;
|
||||
std::string getpeername() override {
|
||||
return socket_->getpeername();
|
||||
}
|
||||
APIError close() override;
|
||||
APIError shutdown(int how) override;
|
||||
// Give this helper a name for logging
|
||||
void set_log_info(std::string info) override {
|
||||
info_ = std::move(info);
|
||||
}
|
||||
|
||||
protected:
|
||||
struct ParsedFrame {
|
||||
std::vector<uint8_t> msg;
|
||||
};
|
||||
|
||||
APIError try_read_frame_(ParsedFrame *frame);
|
||||
APIError try_send_tx_buf_();
|
||||
APIError write_frame_(const uint8_t *data, size_t len);
|
||||
APIError write_raw_(const uint8_t *data, size_t len);
|
||||
|
||||
std::unique_ptr<socket::Socket> socket_;
|
||||
|
||||
std::string info_;
|
||||
std::vector<uint8_t> rx_header_buf_;
|
||||
bool rx_header_parsed_ = false;
|
||||
uint32_t rx_header_parsed_type_ = 0;
|
||||
uint32_t rx_header_parsed_len_ = 0;
|
||||
|
||||
std::vector<uint8_t> rx_buf_;
|
||||
size_t rx_buf_len_ = 0;
|
||||
|
||||
std::vector<uint8_t> tx_buf_;
|
||||
|
||||
enum class State {
|
||||
INITIALIZE = 1,
|
||||
DATA = 2,
|
||||
CLOSED = 3,
|
||||
FAILED = 4,
|
||||
} state_ = State::INITIALIZE;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
#pragma once
|
||||
#include <cstdint>
|
||||
#include <array>
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
using psk_t = std::array<uint8_t, 32>;
|
||||
|
||||
class APINoiseContext {
|
||||
|
@ -19,6 +21,7 @@ class APINoiseContext {
|
|||
protected:
|
||||
psk_t psk_;
|
||||
};
|
||||
#endif // USE_API_NOISE
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
|
|
|
@ -31,12 +31,14 @@ class APIServer : public Component, public Controller {
|
|||
void set_password(const std::string &password);
|
||||
void set_reboot_timeout(uint32_t reboot_timeout);
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
void set_noise_psk(psk_t psk) {
|
||||
noise_ctx_->set_psk(std::move(psk));
|
||||
}
|
||||
std::shared_ptr<APINoiseContext> get_noise_ctx() {
|
||||
return noise_ctx_;
|
||||
}
|
||||
#endif // USE_API_NOISE
|
||||
|
||||
void handle_disconnect(APIConnection *conn);
|
||||
#ifdef USE_BINARY_SENSOR
|
||||
|
@ -97,7 +99,10 @@ class APIServer : public Component, public Controller {
|
|||
std::string password_;
|
||||
std::vector<HomeAssistantStateSubscription> state_subs_;
|
||||
std::vector<UserServiceDescriptor *> user_services_;
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
std::shared_ptr<APINoiseContext> noise_ctx_ = std::make_shared<APINoiseContext>();
|
||||
#endif // USE_API_NOISE
|
||||
};
|
||||
|
||||
extern APIServer *global_api_server; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import esphome.config_validation as cv
|
||||
import esphome.codegen as cg
|
||||
|
||||
# Dummy package to allow components to depend on network
|
||||
CODEOWNERS = ["@esphome/core"]
|
||||
|
||||
CONF_IMPLEMENTATION = "implementation"
|
||||
|
|
|
@ -47,11 +47,6 @@ class BSDSocketImpl : public Socket {
|
|||
closed_ = true;
|
||||
return ret;
|
||||
}
|
||||
int connect(const std::string &address) override {
|
||||
// TODO
|
||||
return 0;
|
||||
}
|
||||
int connect(const struct sockaddr *addr, socklen_t addrlen) override { return ::connect(fd_, addr, addrlen); }
|
||||
int shutdown(int how) override { return ::shutdown(fd_, how); }
|
||||
|
||||
int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { return ::getpeername(fd_, addr, addrlen); }
|
||||
|
@ -79,9 +74,7 @@ class BSDSocketImpl : public Socket {
|
|||
return ::setsockopt(fd_, level, optname, optval, optlen);
|
||||
}
|
||||
int listen(int backlog) override { return ::listen(fd_, backlog); }
|
||||
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
|
||||
ssize_t read(void *buf, size_t len) override { return ::read(fd_, buf, len); }
|
||||
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
|
||||
ssize_t write(const void *buf, size_t len) override { return ::write(fd_, buf, len); }
|
||||
int setblocking(bool blocking) override {
|
||||
int fl = ::fcntl(fd_, F_GETFL, 0);
|
||||
|
|
|
@ -29,7 +29,6 @@ class LWIPRawImpl : public Socket {
|
|||
}
|
||||
|
||||
void init() {
|
||||
ESP_LOGD(TAG, "init()");
|
||||
tcp_arg(pcb_, this);
|
||||
tcp_accept(pcb_, LWIPRawImpl::s_accept_fn);
|
||||
tcp_recv(pcb_, LWIPRawImpl::s_recv_fn);
|
||||
|
@ -98,8 +97,7 @@ class LWIPRawImpl : public Socket {
|
|||
port = ntohs(addr4->sin_port);
|
||||
ip.addr = addr4->sin_addr.s_addr;
|
||||
#endif
|
||||
err_t err = tcp_bind(pcb_, IP4_ADDR_ANY, port);
|
||||
ESP_LOGD(TAG, "bind(ip=%u, port=%u) -> %d", ip.addr, port, err);
|
||||
err_t err = tcp_bind(pcb_, &ip, port);
|
||||
if (err == ERR_USE) {
|
||||
errno = EADDRINUSE;
|
||||
return -1;
|
||||
|
@ -129,14 +127,6 @@ class LWIPRawImpl : public Socket {
|
|||
pcb_ = nullptr;
|
||||
return 0;
|
||||
}
|
||||
int connect(const std::string &address) override {
|
||||
// TODO
|
||||
return -1;
|
||||
}
|
||||
int connect(const struct sockaddr *addr, socklen_t addrlen) override {
|
||||
// TODO
|
||||
return -1;
|
||||
}
|
||||
int shutdown(int how) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
|
@ -226,7 +216,29 @@ class LWIPRawImpl : public Socket {
|
|||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
// TODO
|
||||
if (level == SOL_SOCKET && optname == SO_REUSEADDR) {
|
||||
if (optlen < 4) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// lwip doesn't seem to have this feature. Don't send an error
|
||||
// to prevent warnings
|
||||
*reinterpret_cast<int *>(optval) = 1;
|
||||
*optlen = 4;
|
||||
return 0;
|
||||
}
|
||||
if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
|
||||
if (optlen < 4) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
*reinterpret_cast<int *>(optval) = tcp_nagle_disabled(pcb_);
|
||||
*optlen = 4;
|
||||
return 0;
|
||||
}
|
||||
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
|
||||
|
@ -240,7 +252,8 @@ class LWIPRawImpl : public Socket {
|
|||
return -1;
|
||||
}
|
||||
|
||||
// TODO
|
||||
// lwip doesn't seem to have this feature. Don't send an error
|
||||
// to prevent warnings
|
||||
return 0;
|
||||
}
|
||||
if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
|
||||
|
@ -266,7 +279,6 @@ class LWIPRawImpl : public Socket {
|
|||
return -1;
|
||||
}
|
||||
struct tcp_pcb *listen_pcb = tcp_listen_with_backlog(pcb_, backlog);
|
||||
ESP_LOGD(TAG, "listen(%d) -> %p", backlog, listen_pcb);
|
||||
if (listen_pcb == nullptr) {
|
||||
tcp_abort(pcb_);
|
||||
pcb_ = nullptr;
|
||||
|
@ -286,7 +298,7 @@ class LWIPRawImpl : public Socket {
|
|||
return -1;
|
||||
}
|
||||
if (rx_closed_ && rx_buf_ == nullptr) {
|
||||
errno = ECONNRESET; // TODO: is this the right errno?
|
||||
errno = ECONNRESET;
|
||||
return -1;
|
||||
}
|
||||
if (len == 0) {
|
||||
|
@ -333,7 +345,6 @@ class LWIPRawImpl : public Socket {
|
|||
|
||||
return read;
|
||||
}
|
||||
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
|
||||
ssize_t write(const void *buf, size_t len) {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
|
@ -367,7 +378,6 @@ class LWIPRawImpl : public Socket {
|
|||
}
|
||||
return to_send;
|
||||
}
|
||||
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
|
||||
int setblocking(bool blocking) {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
|
@ -382,20 +392,32 @@ class LWIPRawImpl : public Socket {
|
|||
}
|
||||
|
||||
err_t accept_fn(struct tcp_pcb *newpcb, err_t err) {
|
||||
// TODO: check err
|
||||
if (err != ERR_OK || newpcb == 0) {
|
||||
// "An error code if there has been an error accepting. Only return ERR_ABRT if you have
|
||||
// called tcp_abort from within the callback function!"
|
||||
// https://www.nongnu.org/lwip/2_1_x/tcp_8h.html#a00517abce6856d6c82f0efebdafb734d
|
||||
// nothing to do here, we just don't push it to the queue
|
||||
return ERR_OK;
|
||||
}
|
||||
accepted_sockets_.emplace(new LWIPRawImpl(newpcb));
|
||||
ESP_LOGD(TAG, "accept_fn newpcb=%p err=%d", newpcb, err);
|
||||
return ERR_OK;
|
||||
}
|
||||
void err_fn(err_t err) {
|
||||
ESP_LOGD(TAG, "err_fn err=%d", err);
|
||||
// "If a connection is aborted because of an error, the application is alerted of this event by
|
||||
// the err callback."
|
||||
// pcb is already freed when this callback is called
|
||||
// ERR_RST: connection was reset by remote host
|
||||
// ERR_ABRT: aborted through tcp_abort or TCP timer
|
||||
pcb_ = nullptr;
|
||||
}
|
||||
err_t recv_fn(struct pbuf *pb, err_t err) {
|
||||
// TODO: check err
|
||||
ESP_LOGD(TAG, "recv_fn pb=%p err=%d", pb, err);
|
||||
if (err != 0) {
|
||||
// "An error code if there has been an error receiving Only return ERR_ABRT if you have
|
||||
// called tcp_abort from within the callback function!"
|
||||
rx_closed_ = true;
|
||||
return ERR_OK;
|
||||
}
|
||||
if (pb == nullptr) {
|
||||
// remote host has closed the connection
|
||||
// TODO
|
||||
rx_closed_ = true;
|
||||
return ERR_OK;
|
||||
}
|
||||
|
|
|
@ -18,8 +18,9 @@ class Socket {
|
|||
virtual std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
virtual int close() = 0;
|
||||
virtual int connect(const std::string &address) = 0;
|
||||
virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
// not supported yet:
|
||||
// virtual int connect(const std::string &address) = 0;
|
||||
// virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
virtual int shutdown(int how) = 0;
|
||||
|
||||
virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
|
@ -30,9 +31,7 @@ class Socket {
|
|||
virtual int setsockopt(int level, int optname, const void *optval, socklen_t optlen) = 0;
|
||||
virtual int listen(int backlog) = 0;
|
||||
virtual ssize_t read(void *buf, size_t len) = 0;
|
||||
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
|
||||
virtual ssize_t write(const void *buf, size_t len) = 0;
|
||||
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
|
||||
virtual int setblocking(bool blocking) = 0;
|
||||
virtual int loop() { return 0; };
|
||||
};
|
||||
|
|
|
@ -130,7 +130,6 @@ class MbedTLSWrappedSocket : public socket::Socket {
|
|||
int ret = mbedtls_ssl_read(&ssl_, reinterpret_cast<uint8_t *>(buf), len);
|
||||
return this->mbedtls_to_errno_(ret);
|
||||
}
|
||||
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
|
||||
ssize_t write(const void *buf, size_t len) override {
|
||||
loop();
|
||||
if (do_handshake_) {
|
||||
|
@ -140,7 +139,6 @@ class MbedTLSWrappedSocket : public socket::Socket {
|
|||
int ret = mbedtls_ssl_write(&ssl_, reinterpret_cast<const uint8_t *>(buf), len);
|
||||
return this->mbedtls_to_errno_(ret);
|
||||
}
|
||||
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
|
||||
int setblocking(bool blocking) override {
|
||||
// TODO: handle blocking modes
|
||||
return sock_->setblocking(blocking);
|
||||
|
|
|
@ -31,3 +31,5 @@
|
|||
#define USE_MDNS
|
||||
#define USE_SOCKET_IMPL_LWIP_TCP
|
||||
#define USE_SOCKET_IMPL_BSD_SOCKETS
|
||||
#define USE_API_NOISE
|
||||
#define USE_API_PLAINTEXT
|
||||
|
|
Loading…
Reference in a new issue