This commit is contained in:
Otto winter 2021-09-06 12:44:53 +02:00
parent 44041d2526
commit 88632b22e2
No known key found for this signature in database
GPG key ID: 48ED2DDB96D7682C
12 changed files with 470 additions and 71 deletions

View file

@ -120,6 +120,9 @@ async def to_code(config):
conf = config[CONF_ENCRYPTION]
decoded = base64.b64decode(conf[CONF_KEY])
cg.add(var.set_noise_psk(list(decoded)))
cg.add_define("USE_API_NOISE")
else:
cg.add_define("USE_API_PLAINTEXT")
cg.add_define("USE_API")
cg.add_global(api_ns.using)

View file

@ -25,7 +25,14 @@ APIConnection::APIConnection(std::unique_ptr<socket::Socket> sock, APIServer *pa
list_entities_iterator_(parent, this) {
this->proto_write_buffer_.reserve(64);
#ifdef USE_API_NOISE
helper_ = std::unique_ptr<APIFrameHelper>{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())};
#elif defined(USE_API_PLAINTEXT)
helper_ = std::unique_ptr<APIFrameHelper>{new APIPlaintextFrameHelper(std::move(sock))};
#else
#error "No api frame helper enabled"
#endif
}
void APIConnection::start() {
this->last_traffic_ = millis();

View file

@ -2,16 +2,12 @@
#include "esphome/core/log.h"
#include "esphome/core/helpers.h"
#include "proto.h"
namespace esphome {
namespace api {
static const char *const TAG = "api.socket";
static const char *const PROLOGUE_INIT = "NoiseAPIInit";
// TODO:
// - track errors internally and return if in bad state
// - send error on invalid psk
/// Is the given return value (from read/write syscalls) a wouldblock error?
bool is_would_block(ssize_t ret) {
@ -21,7 +17,10 @@ bool is_would_block(ssize_t ret) {
return ret == 0;
}
#define HELPER_LOG(msg, ...) ESP_LOGW(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__)
#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__)
#ifdef USE_API_NOISE
static const char *const PROLOGUE_INIT = "NoiseAPIInit";
/// Convert a noise error code to a readable error
std::string noise_err_to_str(int err) {
@ -93,7 +92,15 @@ APIError APINoiseFrameHelper::loop() {
APIError err = state_action_();
if (err == APIError::WOULD_BLOCK)
return APIError::OK;
return err;
if (err != APIError::OK)
return err;
if (!tx_buf_.empty()) {
err = try_send_tx_buf_();
if (err != APIError::OK) {
return err;
}
}
return APIError::OK;
}
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
@ -180,7 +187,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) {
}
}
// ESP_LOGD(TAG, "Received frame: %s", hexencode(rx_buf_).c_str());
// uncomment for even more debugging
// ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str());
frame->msg = std::move(rx_buf_);
// consume msg
rx_buf_ = {};
@ -239,17 +247,33 @@ APIError APINoiseFrameHelper::state_action_() {
// waiting for handshake msg
ParsedFrame frame;
aerr = try_read_frame_(&frame);
if (aerr == APIError::BAD_INDICATOR) {
send_explicit_handshake_reject_("Bad indicator byte");
return aerr;
}
if (frame.msg.size() < 1 || frame.msg[0] != 0x00) {
aerr = APIError::BAD_HANDSHAKE_PACKET_LEN;
}
if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) {
send_explicit_handshake_reject_("Bad handshake packet len");
return aerr;
}
if (aerr != APIError::OK)
return aerr;
NoiseBuffer mbuf;
noise_buffer_init(mbuf);
noise_buffer_set_input(mbuf, frame.msg.data(), frame.msg.size());
noise_buffer_set_input(mbuf, frame.msg.data() + 1, frame.msg.size() - 1);
err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
if (err != 0) {
// TODO: explicit rejection
state_ = State::FAILED;
HELPER_LOG("noise_handshakestate_read_message failed: %s", noise_err_to_str(err).c_str());
if (err == NOISE_ERROR_MAC_FAILURE) {
send_explicit_handshake_reject_("Handshake MAC failure");
} else {
send_explicit_handshake_reject_("Handshake error");
}
return APIError::HANDSHAKESTATE_READ_FAILED;
}
@ -257,10 +281,10 @@ APIError APINoiseFrameHelper::state_action_() {
if (aerr != APIError::OK)
return aerr;
} else if (action == NOISE_ACTION_WRITE_MESSAGE) {
uint8_t buffer[64];
uint8_t buffer[65];
NoiseBuffer mbuf;
noise_buffer_init(mbuf);
noise_buffer_set_output(mbuf, buffer, sizeof(buffer));
noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1);
err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr);
if (err != 0) {
@ -268,7 +292,9 @@ APIError APINoiseFrameHelper::state_action_() {
HELPER_LOG("noise_handshakestate_write_message failed: %s", noise_err_to_str(err).c_str());
return APIError::HANDSHAKESTATE_WRITE_FAILED;
}
aerr = write_frame_(mbuf.data, mbuf.size);
buffer[0] = 0x00; // success
aerr = write_frame_(buffer, mbuf.size + 1);
if (aerr != APIError::OK)
return aerr;
aerr = check_handshake_finished_();
@ -286,6 +312,15 @@ APIError APINoiseFrameHelper::state_action_() {
}
return APIError::OK;
}
void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) {
std::vector<uint8_t> data;
data.reserve(reason.size() + 1);
data[0] = 0x01; // failure
for (size_t i = 0; i < reason.size(); i++) {
data[i+1] = (uint8_t) reason[i];
}
write_frame_(data.data(), data.size());
}
APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
int err;
@ -397,7 +432,7 @@ APIError APINoiseFrameHelper::write_packet(uint16_t type, const uint8_t *payload
}
return APIError::OK;
}
APIError APINoiseFrameHelper::try_send_raw_() {
APIError APINoiseFrameHelper::try_send_tx_buf_() {
// try send from tx_buf
while (state_ != State::CLOSED && !tx_buf_.empty()) {
ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size());
@ -428,11 +463,12 @@ APIError APINoiseFrameHelper::write_raw_(const uint8_t *data, size_t len) {
int err;
APIError aerr;
// ESP_LOGD(TAG, "Sending raw: %s", hexencode(data, len).c_str());
// uncomment for even more debugging
// ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str());
if (!tx_buf_.empty()) {
// try to empty tx_buf_ first
aerr = try_send_raw_();
aerr = try_send_tx_buf_();
if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK)
return aerr;
}
@ -582,15 +618,290 @@ APIError APINoiseFrameHelper::shutdown(int how) {
}
return APIError::OK;
}
} // namespace api
} // namespace esphome
extern "C" {
// declare how noise generates random bytes (here with a good HWRNG based on the RF system)
void noise_rand_bytes(void *output, size_t len) {
esphome::fill_random(reinterpret_cast<uint8_t *>(output), len);
}
}
#endif // USE_API_NOISE
#ifdef USE_API_PLAINTEXT
/// Initialize the frame helper, returns OK if successful.
APIError APIPlaintextFrameHelper::init() {
if (state_ != State::INITIALIZE || socket_ == nullptr) {
HELPER_LOG("Bad state for init %d", (int) state_);
return APIError::BAD_STATE;
}
int err = socket_->setblocking(false);
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("Setting nonblocking failed with errno %d", errno);
return APIError::TCP_NONBLOCKING_FAILED;
}
int enable = 1;
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("Setting nodelay failed with errno %d", errno);
return APIError::TCP_NODELAY_FAILED;
}
state_ = State::DATA;
return APIError::OK;
}
/// Not used for plaintext
APIError APIPlaintextFrameHelper::loop() {
if (state_ != State::DATA) {
return APIError::BAD_STATE;
}
// try send pending TX data
if (!tx_buf_.empty()) {
APIError err = try_send_tx_buf_();
if (err != APIError::OK) {
return err;
}
}
return APIError::OK;
}
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
*
* @param frame: The struct to hold the frame information in.
* msg: store the parsed frame in that struct
*
* @return See APIError
*
* error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
*/
APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) {
int err;
APIError aerr;
if (frame == nullptr) {
HELPER_LOG("Bad argument for try_read_frame_");
return APIError::BAD_ARG;
}
// read header
while (!rx_header_parsed_) {
uint8_t data;
ssize_t received = socket_->read(&data, 1);
if (is_would_block(received)) {
return APIError::WOULD_BLOCK;
} else if (received == -1) {
state_ = State::FAILED;
HELPER_LOG("Socket read failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
}
rx_header_buf_.push_back(data);
// try parse header
if (rx_header_buf_[0] != 0x00) {
state_ = State::FAILED;
HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]);
return APIError::BAD_INDICATOR;
}
size_t i = 1;
size_t consumed = 0;
auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[i], rx_header_buf_.size() - i, &consumed);
if (!msg_size_varint.has_value()) {
// not enough data there yet
continue;
}
i += consumed;
rx_header_parsed_len_ = msg_size_varint->as_uint32();
auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[i], rx_header_buf_.size() - i, &consumed);
if (!msg_type_varint.has_value()) {
// not enough data there yet
continue;
}
rx_header_parsed_type_ = msg_type_varint->as_uint32();
rx_header_parsed_ = true;
}
// header reading done
// reserve space for body
if (rx_buf_.size() != rx_header_parsed_len_) {
rx_buf_.resize(rx_header_parsed_len_);
}
if (rx_buf_len_ < rx_header_parsed_len_) {
// more data to read
size_t to_read = rx_header_parsed_len_ - rx_buf_len_;
ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read);
if (is_would_block(received)) {
return APIError::WOULD_BLOCK;
} else if (received == -1) {
state_ = State::FAILED;
HELPER_LOG("Socket read failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
}
rx_buf_len_ += received;
if (received != to_read) {
// not all read
return APIError::WOULD_BLOCK;
}
}
// uncomment for even more debugging
// ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str());
frame->msg = std::move(rx_buf_);
// consume msg
rx_buf_ = {};
rx_buf_len_ = 0;
rx_header_buf_.clear();
rx_header_parsed_ = false;
return APIError::OK;
}
APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
int err;
APIError aerr;
if (state_ != State::DATA) {
return APIError::WOULD_BLOCK;
}
ParsedFrame frame;
aerr = try_read_frame_(&frame);
if (aerr != APIError::OK)
return aerr;
buffer->container = std::move(frame.msg);
buffer->data_offset = 0;
buffer->data_len = rx_header_parsed_len_;
buffer->type = rx_header_parsed_type_;
return APIError::OK;
}
bool APIPlaintextFrameHelper::can_write_without_blocking() {
return state_ == State::DATA && tx_buf_.empty();
}
APIError APIPlaintextFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) {
int err;
APIError aerr;
if (state_ != State::DATA) {
return APIError::BAD_STATE;
}
std::vector<uint8_t> header;
header.push_back(0x00);
ProtoVarInt(payload_len).encode(header);
ProtoVarInt(type).encode(header);
aerr = write_raw_(&header[0], header.size());
if (aerr != APIError::OK) {
return aerr;
}
aerr = write_raw_(payload, payload_len);
if (aerr != APIError::OK) {
return aerr;
}
return APIError::OK;
}
APIError APIPlaintextFrameHelper::try_send_tx_buf_() {
// try send from tx_buf
while (state_ != State::CLOSED && !tx_buf_.empty()) {
ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size());
if (sent == -1) {
if (errno == EWOULDBLOCK || errno == EAGAIN)
break;
state_ = State::FAILED;
HELPER_LOG("Socket write failed with errno %d", errno);
return APIError::SOCKET_WRITE_FAILED;
} else if (sent == 0) {
break;
}
// TODO: inefficient if multiple packets in txbuf
// replace with deque of buffers
tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent);
}
return APIError::OK;
}
/** Write the data to the socket, or buffer it a write would block
*
* @param data The data to write
* @param len The length of data
*/
APIError APIPlaintextFrameHelper::write_raw_(const uint8_t *data, size_t len) {
if (len == 0)
return APIError::OK;
int err;
APIError aerr;
// uncomment for even more debugging
// ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str());
if (!tx_buf_.empty()) {
// try to empty tx_buf_ first
aerr = try_send_tx_buf_();
if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK)
return aerr;
}
if (!tx_buf_.empty()) {
// tx buf not empty, can't write now because then stream would be inconsistent
tx_buf_.insert(tx_buf_.end(), data, data + len);
return APIError::OK;
}
ssize_t sent = socket_->write(data, len);
if (is_would_block(sent)) {
// operation would block, add buffer to tx_buf
tx_buf_.insert(tx_buf_.end(), data, data + len);
return APIError::OK;
} else if (sent == -1) {
// an error occured
state_ = State::FAILED;
HELPER_LOG("Socket write failed with errno %d", errno);
return APIError::SOCKET_WRITE_FAILED;
} else if (sent != len) {
// partially sent, add end to tx_buf
tx_buf_.insert(tx_buf_.end(), data + sent, data + len);
return APIError::OK;
}
// fully sent
return APIError::OK;
}
APIError APIPlaintextFrameHelper::write_frame_(const uint8_t *data, size_t len) {
APIError aerr;
uint8_t header[3];
header[0] = 0x01; // indicator
header[1] = (uint8_t) (len >> 8);
header[2] = (uint8_t) len;
aerr = write_raw_(header, 3);
if (aerr != APIError::OK)
return aerr;
aerr = write_raw_(data, len);
return aerr;
}
APIError APIPlaintextFrameHelper::close() {
state_ = State::CLOSED;
int err = socket_->close();
if (err == -1)
return APIError::CLOSE_FAILED;
return APIError::OK;
}
APIError APIPlaintextFrameHelper::shutdown(int how) {
int err = socket_->shutdown(how);
if (err == -1)
return APIError::SHUTDOWN_FAILED;
if (how == SHUT_RDWR) {
state_ = State::CLOSED;
}
return APIError::OK;
}
#endif // USE_API_PLAINTEXT
} // namespace api
} // namespace esphome

View file

@ -3,7 +3,11 @@
#include <vector>
#include <deque>
#include "esphome/core/defines.h"
#ifdef USE_API_NOISE
#include "noise/protocol.h"
#endif
#include "esphome/components/socket/socket.h"
#include "api_noise_context.h"
@ -63,39 +67,39 @@ class APIFrameHelper {
virtual void set_log_info(std::string info) = 0;
};
#ifdef USE_API_NOISE
class APINoiseFrameHelper : public APIFrameHelper {
public:
APINoiseFrameHelper(std::unique_ptr<socket::Socket> socket, std::shared_ptr<APINoiseContext> ctx) : socket_(std::move(socket)), ctx_(ctx) {}
~APINoiseFrameHelper();
APIError init();
APIError loop();
APIError read_packet(ReadPacketBuffer *buffer);
bool can_write_without_blocking();
APIError write_packet(uint16_t type, const uint8_t *data, size_t len);
std::string getpeername() {
APIError init() override;
APIError loop() override;
APIError read_packet(ReadPacketBuffer *buffer) override;
bool can_write_without_blocking() override;
APIError write_packet(uint16_t type, const uint8_t *data, size_t len) override;
std::string getpeername() override{
return socket_->getpeername();
}
APIError close();
APIError shutdown(int how);
APIError close() override;
APIError shutdown(int how) override;
// Give this helper a name for logging
void set_log_info(std::string info) {
void set_log_info(std::string info) override {
info_ = std::move(info);
}
protected:
APIError reserve_rx_buf_(size_t new_capacity);
struct ParsedFrame {
std::vector<uint8_t> msg;
};
APIError state_action_();
APIError try_read_frame_(ParsedFrame *frame);
APIError try_send_raw_();
APIError try_send_tx_buf_();
APIError write_frame_(const uint8_t *data, size_t len);
APIError write_raw_(const uint8_t *data, size_t len);
APIError init_handshake_();
APIError check_handshake_finished_();
void send_explicit_handshake_reject_(const std::string &reason);
std::unique_ptr<socket::Socket> socket_;
@ -124,6 +128,59 @@ class APINoiseFrameHelper : public APIFrameHelper {
FAILED = 7,
} state_ = State::INITIALIZE;
};
#endif // USE_API_NOISE
#ifdef USE_API_PLAINTEXT
class APIPlaintextFrameHelper : public APIFrameHelper {
public:
APIPlaintextFrameHelper(std::unique_ptr<socket::Socket> socket) : socket_(std::move(socket)) {}
~APIPlaintextFrameHelper() = default;
APIError init() override;
APIError loop() override;
APIError read_packet(ReadPacketBuffer *buffer) override;
bool can_write_without_blocking() override;
APIError write_packet(uint16_t type, const uint8_t *data, size_t len) override;
std::string getpeername() override {
return socket_->getpeername();
}
APIError close() override;
APIError shutdown(int how) override;
// Give this helper a name for logging
void set_log_info(std::string info) override {
info_ = std::move(info);
}
protected:
struct ParsedFrame {
std::vector<uint8_t> msg;
};
APIError try_read_frame_(ParsedFrame *frame);
APIError try_send_tx_buf_();
APIError write_frame_(const uint8_t *data, size_t len);
APIError write_raw_(const uint8_t *data, size_t len);
std::unique_ptr<socket::Socket> socket_;
std::string info_;
std::vector<uint8_t> rx_header_buf_;
bool rx_header_parsed_ = false;
uint32_t rx_header_parsed_type_ = 0;
uint32_t rx_header_parsed_len_ = 0;
std::vector<uint8_t> rx_buf_;
size_t rx_buf_len_ = 0;
std::vector<uint8_t> tx_buf_;
enum class State {
INITIALIZE = 1,
DATA = 2,
CLOSED = 3,
FAILED = 4,
} state_ = State::INITIALIZE;
};
#endif
} // namespace api
} // namespace esphome

View file

@ -1,10 +1,12 @@
#pragma once
#include <cstdint>
#include <array>
#include "esphome/core/defines.h"
namespace esphome {
namespace api {
#ifdef USE_API_NOISE
using psk_t = std::array<uint8_t, 32>;
class APINoiseContext {
@ -19,6 +21,7 @@ class APINoiseContext {
protected:
psk_t psk_;
};
#endif // USE_API_NOISE
} // namespace api
} // namespace esphome

View file

@ -31,12 +31,14 @@ class APIServer : public Component, public Controller {
void set_password(const std::string &password);
void set_reboot_timeout(uint32_t reboot_timeout);
#ifdef USE_API_NOISE
void set_noise_psk(psk_t psk) {
noise_ctx_->set_psk(std::move(psk));
}
std::shared_ptr<APINoiseContext> get_noise_ctx() {
return noise_ctx_;
}
#endif // USE_API_NOISE
void handle_disconnect(APIConnection *conn);
#ifdef USE_BINARY_SENSOR
@ -97,7 +99,10 @@ class APIServer : public Component, public Controller {
std::string password_;
std::vector<HomeAssistantStateSubscription> state_subs_;
std::vector<UserServiceDescriptor *> user_services_;
#ifdef USE_API_NOISE
std::shared_ptr<APINoiseContext> noise_ctx_ = std::make_shared<APINoiseContext>();
#endif // USE_API_NOISE
};
extern APIServer *global_api_server; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

View file

@ -1,7 +1,6 @@
import esphome.config_validation as cv
import esphome.codegen as cg
# Dummy package to allow components to depend on network
CODEOWNERS = ["@esphome/core"]
CONF_IMPLEMENTATION = "implementation"

View file

@ -47,11 +47,6 @@ class BSDSocketImpl : public Socket {
closed_ = true;
return ret;
}
int connect(const std::string &address) override {
// TODO
return 0;
}
int connect(const struct sockaddr *addr, socklen_t addrlen) override { return ::connect(fd_, addr, addrlen); }
int shutdown(int how) override { return ::shutdown(fd_, how); }
int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { return ::getpeername(fd_, addr, addrlen); }
@ -79,9 +74,7 @@ class BSDSocketImpl : public Socket {
return ::setsockopt(fd_, level, optname, optval, optlen);
}
int listen(int backlog) override { return ::listen(fd_, backlog); }
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
ssize_t read(void *buf, size_t len) override { return ::read(fd_, buf, len); }
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
ssize_t write(const void *buf, size_t len) override { return ::write(fd_, buf, len); }
int setblocking(bool blocking) override {
int fl = ::fcntl(fd_, F_GETFL, 0);

View file

@ -29,7 +29,6 @@ class LWIPRawImpl : public Socket {
}
void init() {
ESP_LOGD(TAG, "init()");
tcp_arg(pcb_, this);
tcp_accept(pcb_, LWIPRawImpl::s_accept_fn);
tcp_recv(pcb_, LWIPRawImpl::s_recv_fn);
@ -98,8 +97,7 @@ class LWIPRawImpl : public Socket {
port = ntohs(addr4->sin_port);
ip.addr = addr4->sin_addr.s_addr;
#endif
err_t err = tcp_bind(pcb_, IP4_ADDR_ANY, port);
ESP_LOGD(TAG, "bind(ip=%u, port=%u) -> %d", ip.addr, port, err);
err_t err = tcp_bind(pcb_, &ip, port);
if (err == ERR_USE) {
errno = EADDRINUSE;
return -1;
@ -129,14 +127,6 @@ class LWIPRawImpl : public Socket {
pcb_ = nullptr;
return 0;
}
int connect(const std::string &address) override {
// TODO
return -1;
}
int connect(const struct sockaddr *addr, socklen_t addrlen) override {
// TODO
return -1;
}
int shutdown(int how) override {
if (pcb_ == nullptr) {
errno = EBADF;
@ -226,7 +216,29 @@ class LWIPRawImpl : public Socket {
errno = EBADF;
return -1;
}
// TODO
if (level == SOL_SOCKET && optname == SO_REUSEADDR) {
if (optlen < 4) {
errno = EINVAL;
return -1;
}
// lwip doesn't seem to have this feature. Don't send an error
// to prevent warnings
*reinterpret_cast<int *>(optval) = 1;
*optlen = 4;
return 0;
}
if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
if (optlen < 4) {
errno = EINVAL;
return -1;
}
*reinterpret_cast<int *>(optval) = tcp_nagle_disabled(pcb_);
*optlen = 4;
return 0;
}
errno = EINVAL;
return -1;
}
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
@ -240,7 +252,8 @@ class LWIPRawImpl : public Socket {
return -1;
}
// TODO
// lwip doesn't seem to have this feature. Don't send an error
// to prevent warnings
return 0;
}
if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
@ -266,7 +279,6 @@ class LWIPRawImpl : public Socket {
return -1;
}
struct tcp_pcb *listen_pcb = tcp_listen_with_backlog(pcb_, backlog);
ESP_LOGD(TAG, "listen(%d) -> %p", backlog, listen_pcb);
if (listen_pcb == nullptr) {
tcp_abort(pcb_);
pcb_ = nullptr;
@ -286,7 +298,7 @@ class LWIPRawImpl : public Socket {
return -1;
}
if (rx_closed_ && rx_buf_ == nullptr) {
errno = ECONNRESET; // TODO: is this the right errno?
errno = ECONNRESET;
return -1;
}
if (len == 0) {
@ -333,7 +345,6 @@ class LWIPRawImpl : public Socket {
return read;
}
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
ssize_t write(const void *buf, size_t len) {
if (pcb_ == nullptr) {
errno = EBADF;
@ -367,7 +378,6 @@ class LWIPRawImpl : public Socket {
}
return to_send;
}
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
int setblocking(bool blocking) {
if (pcb_ == nullptr) {
errno = EBADF;
@ -382,20 +392,32 @@ class LWIPRawImpl : public Socket {
}
err_t accept_fn(struct tcp_pcb *newpcb, err_t err) {
// TODO: check err
if (err != ERR_OK || newpcb == 0) {
// "An error code if there has been an error accepting. Only return ERR_ABRT if you have
// called tcp_abort from within the callback function!"
// https://www.nongnu.org/lwip/2_1_x/tcp_8h.html#a00517abce6856d6c82f0efebdafb734d
// nothing to do here, we just don't push it to the queue
return ERR_OK;
}
accepted_sockets_.emplace(new LWIPRawImpl(newpcb));
ESP_LOGD(TAG, "accept_fn newpcb=%p err=%d", newpcb, err);
return ERR_OK;
}
void err_fn(err_t err) {
ESP_LOGD(TAG, "err_fn err=%d", err);
// "If a connection is aborted because of an error, the application is alerted of this event by
// the err callback."
// pcb is already freed when this callback is called
// ERR_RST: connection was reset by remote host
// ERR_ABRT: aborted through tcp_abort or TCP timer
pcb_ = nullptr;
}
err_t recv_fn(struct pbuf *pb, err_t err) {
// TODO: check err
ESP_LOGD(TAG, "recv_fn pb=%p err=%d", pb, err);
if (err != 0) {
// "An error code if there has been an error receiving Only return ERR_ABRT if you have
// called tcp_abort from within the callback function!"
rx_closed_ = true;
return ERR_OK;
}
if (pb == nullptr) {
// remote host has closed the connection
// TODO
rx_closed_ = true;
return ERR_OK;
}

View file

@ -18,8 +18,9 @@ class Socket {
virtual std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) = 0;
virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0;
virtual int close() = 0;
virtual int connect(const std::string &address) = 0;
virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0;
// not supported yet:
// virtual int connect(const std::string &address) = 0;
// virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0;
virtual int shutdown(int how) = 0;
virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0;
@ -30,9 +31,7 @@ class Socket {
virtual int setsockopt(int level, int optname, const void *optval, socklen_t optlen) = 0;
virtual int listen(int backlog) = 0;
virtual ssize_t read(void *buf, size_t len) = 0;
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
virtual ssize_t write(const void *buf, size_t len) = 0;
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
virtual int setblocking(bool blocking) = 0;
virtual int loop() { return 0; };
};

View file

@ -130,7 +130,6 @@ class MbedTLSWrappedSocket : public socket::Socket {
int ret = mbedtls_ssl_read(&ssl_, reinterpret_cast<uint8_t *>(buf), len);
return this->mbedtls_to_errno_(ret);
}
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
ssize_t write(const void *buf, size_t len) override {
loop();
if (do_handshake_) {
@ -140,7 +139,6 @@ class MbedTLSWrappedSocket : public socket::Socket {
int ret = mbedtls_ssl_write(&ssl_, reinterpret_cast<const uint8_t *>(buf), len);
return this->mbedtls_to_errno_(ret);
}
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
int setblocking(bool blocking) override {
// TODO: handle blocking modes
return sock_->setblocking(blocking);

View file

@ -31,3 +31,5 @@
#define USE_MDNS
#define USE_SOCKET_IMPL_LWIP_TCP
#define USE_SOCKET_IMPL_BSD_SOCKETS
#define USE_API_NOISE
#define USE_API_PLAINTEXT