diff --git a/esphome/components/mqtt/mqtt_backend.cpp b/esphome/components/mqtt/mqtt_backend.cpp new file mode 100644 index 0000000000..a2e2065857 --- /dev/null +++ b/esphome/components/mqtt/mqtt_backend.cpp @@ -0,0 +1,561 @@ +#include "mqtt_backend.h" + +#ifdef USE_MQTT + +#include "esphome/core/log.h" +#include "esphome/core/hal.h" +#include "esphome/core/helpers.h" + +namespace esphome { +namespace mqtt { + +static const char *TAG = "mqtt.backend"; + +ErrorCode util::ConnectionEstablisher::init(const std::string &host, uint16_t port, uint32_t timeout) { + if (state_ != State::UNINIT) { + enter_error_(); + ESP_LOGD(TAG, "conn init bad state"); + return ErrorCode::BAD_STATE; + } + struct addrinfo hints; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + std::string port_s = to_string(port); + getaddrinfo_ = socket::getaddrinfo_async(host.c_str(), port_s.c_str(), &hints); + if (!getaddrinfo_) { + enter_error_(); + return ErrorCode::RESOLVE_ERROR; + } + start_ = millis(); + timeout_ = timeout; + state_ = State::RESOLVING; + return ErrorCode::OK; +} + +void util::ConnectionEstablisher::enter_error_() { + state_ = State::ERROR; + getaddrinfo_.reset(); + socket_.reset(); +} + +ErrorCode util::ConnectionEstablisher::loop() { + ErrorCode ec; + uint32_t now = millis(); + + switch (state_) { + case State::UNINIT: { + enter_error_(); + state_ = State::ERROR; + ESP_LOGD(TAG, "conn uninit bad state"); + return ErrorCode::BAD_STATE; + } + + case State::RESOLVING: { + if (getaddrinfo_->completed()) { + struct addrinfo *res; + int r = getaddrinfo_->fetch_result(&res); + if (r != 0) { + enter_error_(); + ESP_LOGW(TAG, "Address resolve failed with error %s", gai_strerror(r)); + return ErrorCode::RESOLVE_ERROR; + } + if (res == nullptr) { + enter_error_(); + ESP_LOGW(TAG, "Address resolve returned no results"); + return ErrorCode::RESOLVE_ERROR; + } + + ESP_LOGD(TAG, "address resolved!"); + + socket_ = socket::socket(res->ai_family, res->ai_socktype, res->ai_protocol); + if (!socket_) { + freeaddrinfo(res); + enter_error_(); + ESP_LOGW(TAG, "Socket creation failed with error %s", strerror(errno)); + return ErrorCode::SOCKET_ERROR; + } + + r = socket_->setblocking(false); + if (r != 0) { + enter_error_(); + ESP_LOGV(TAG, "Setting nonblocking socket failed with error %s", strerror(errno)); + return ErrorCode::SOCKET_ERROR; + } + + r = socket_->connect(res->ai_addr, res->ai_addrlen); + freeaddrinfo(res); + + if (r == 0) { + // connection established immediately + getaddrinfo_.reset(); + state_ = State::CONNECTED; + return ErrorCode::OK; + } else if (errno == EINPROGRESS) { + getaddrinfo_.reset(); + state_ = State::CONNECTING; + } else { + enter_error_(); + ESP_LOGW(TAG, "Socket connect failed with error %s", strerror(errno)); + return ErrorCode::SOCKET_ERROR; + } + } + + if (now - start_ >= timeout_) { + enter_error_(); + ESP_LOGW(TAG, "Timeout resolving address"); + return ErrorCode::TIMEOUT; + } + + return ErrorCode::IN_PROGRESS; + } + + case State::CONNECTING: { + int r = socket_->connect_finished(); + if (r == 0) { + // connection established + state_ = State::CONNECTED; + return ErrorCode::OK; + } else if (errno == EINPROGRESS) { + // not established yet + + if (now - start_ >= timeout_) { + enter_error_(); + ESP_LOGW(TAG, "Timeout connecting to address"); + return ErrorCode::TIMEOUT; + } + + return ErrorCode::IN_PROGRESS; + } else { + enter_error_(); + ESP_LOGW(TAG, "Socket connect failed with error %s", strerror(errno)); + return ErrorCode::SOCKET_ERROR; + } + } + + case State::CONNECTED: { + return ErrorCode::OK; + } + + case State::FINISHED: + case State::ERROR: + default: { + return ErrorCode::BAD_STATE; + } + } +} + +std::unique_ptr util::ConnectionEstablisher::extract_socket() { + if (state_ != State::CONNECTED) + return nullptr; + state_ = State::FINISHED; + return std::move(socket_); +} + +ErrorCode util::BufferedWriter::write(const std::unique_ptr &sock, const uint8_t *data, size_t len, bool do_buffer) { + if (len == 0) + return ErrorCode::OK; + ErrorCode ec; + + if (!tx_buf_.empty()) { + // try to empty tx_buf_ first + ec = try_drain(sock); + if (ec != ErrorCode::OK && ec != ErrorCode::WOULD_BLOCK) + return ec; + } + + if (!tx_buf_.empty()) { + // tx buf not empty, can't write now because then stream would be inconsistent + if (!do_buffer) + return ErrorCode::WOULD_BLOCK; + + tx_buf_.insert(tx_buf_.end(), data, data + len); + return ErrorCode::OK; + } + + ssize_t sent = sock->write(data, len); + if (sent == 0 || (sent == -1 && (errno == EWOULDBLOCK || errno == EAGAIN))) { + // operation would block, add to tx_buf if buffering + if (!do_buffer) + return ErrorCode::WOULD_BLOCK; + tx_buf_.insert(tx_buf_.end(), data, data + len); + return ErrorCode::OK; + } else if (sent == -1) { + // an error occured + ESP_LOGV(TAG, "Socket write failed with errno %d", errno); + return ErrorCode::SOCKET_ERROR; + } else if ((size_t) sent != len) { + // partially sent, add end to tx_buf (even if not set to buffering, to prevent + // partial packet transmission) + tx_buf_.insert(tx_buf_.end(), data + sent, data + len); + return ErrorCode::OK; + } + // fully sent + return ErrorCode::OK; +} +ErrorCode util::BufferedWriter::try_drain(const std::unique_ptr &sock) { + // try send from tx_buf + while (!tx_buf_.empty()) { + ssize_t sent = sock->write(tx_buf_.data(), tx_buf_.size()); + if (sent == 0 || (sent == -1 && (errno = EWOULDBLOCK || errno == EAGAIN))) { + break; + } else if (sent == -1) { + ESP_LOGV(TAG, "Socket write failed with errno %d", errno); + return ErrorCode::SOCKET_ERROR; + } + + // TODO: inefficient if multiple packets in txbuf + // replace with deque of buffers + tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent); + } + + return ErrorCode::OK; +} + +ErrorCode MQTTConnection::init(ConnectParams *params, MQTTSession *session) { + if (state_ != State::UNINIT) { + enter_error_(); + return ErrorCode::BAD_STATE; + } + params_ = params; + session_ = session; + connection_establisher_ = make_unique(); + ErrorCode ec = connection_establisher_->init(params_->host, params_->port, 5000); + if (ec != ErrorCode::OK) { + enter_error_(); + return ec; + } + + state_ = State::CONNECTING; + return ErrorCode::OK; +} + +ErrorCode MQTTConnection::loop() { + ErrorCode ec; + + switch (state_) { + case State::CONNECTING: { + ec = connection_establisher_->loop(); + if (ec == ErrorCode::OK) { + // connection established + socket_ = connection_establisher_->extract_socket(); + + ConnectPacket packet{}; + packet.client_id = params_->client_id; + packet.username = params_->username; + packet.password = params_->password; + packet.will_topic = params_->will_topic; + packet.will_message = params_->will_message; + packet.will_qos = params_->will_qos; + packet.will_retain = params_->will_retain; + packet.clean_session = true; + packet.keep_alive = params_->keep_alive; + + std::vector packet_enc; + ec = packet.encode(packet_enc); + if (ec != ErrorCode::OK) { + enter_error_(); + return ec; + } + + ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), true); + if (ec != ErrorCode::OK) { + enter_error_(); + return ec; + } + + state_ = State::WAIT_CONNACK; + connection_establisher_.reset(); + } else if (ec != ErrorCode::IN_PROGRESS) { + enter_error_(); + return ec; + } + return ErrorCode::OK; + } + + case State::WAIT_CONNACK: + case State::CONNECTED: { + ec = writer_.try_drain(socket_); + if (ec != ErrorCode::OK) { + enter_error_(); + return ec; + } + + ec = read_packet_(); + if (ec != ErrorCode::OK && ec != ErrorCode::WOULD_BLOCK) { + enter_error_(); + return ec; + } + + return ErrorCode::OK; + } + + case State::UNINIT: + case State::DISCONNECTED: + case State::ERROR: + default: { + enter_error_(); + ESP_LOGD(TAG, "bad state %d", (int) state_); + return ErrorCode::BAD_STATE; + } + } +} + +void MQTTConnection::enter_error_() { + ESP_LOGD(TAG, "enter_error"); + connection_establisher_.reset(); + socket_.reset(); + rx_header_buf_ = {}; + rx_buf_ = {}; + writer_.stop(); + state_ = State::ERROR; +} + +ErrorCode MQTTConnection::read_packet_() { + if (state_ != State::CONNECTED && state_ != State::WAIT_CONNACK) { + enter_error_(); + ESP_LOGD(TAG, "read_packet_ bad state"); + return ErrorCode::BAD_STATE; + } + ErrorCode ec; + + if (!rx_header_parsed_) { + while (true) { + uint8_t v; + ssize_t received = socket_->read(&v, 1); + if (received == -1 && (errno == EWOULDBLOCK || errno == EAGAIN)) { + // would block + return ErrorCode::WOULD_BLOCK; + } else if (received == -1) { + // error + enter_error_(); + ESP_LOGV(TAG, "Socket read failed with errno %d", errno); + return ErrorCode::SOCKET_ERROR; + } else if (received == 0) { + // EOF + enter_error_(); + ESP_LOGV(TAG, "Socket EOF"); + return ErrorCode::SOCKET_ERROR; + } + rx_header_buf_.push_back(v); + + // try parse buf + if (rx_header_buf_.size() == 1) + continue; + + rx_header_parsed_type_ = (rx_header_buf_[0] >> 4) & 0x0F; + rx_header_parsed_flags_ = (rx_header_buf_[0] >> 0) & 0x0F; + + size_t multiplier = 1, value = 0; + size_t i = 1; + uint8_t enc; + bool parsed = true; + do { + if (i >= rx_header_buf_.size()) { + // not enough data yet + parsed = false; + break; + } + enc = rx_header_buf_[i]; + value += (enc & 0x7F) * multiplier; + multiplier <<= 7; + } while (enc & 0x80); + if (!parsed) + continue; + + rx_header_parsed_ = true; + rx_header_parsed_len_ = value; + } + } + // header reading done + + // reserve space for body + if (rx_buf_.size() != rx_header_parsed_len_) { + rx_buf_.resize(rx_header_parsed_len_); + } + + if (rx_buf_len_ < rx_header_parsed_len_) { + // more data to read + size_t to_read = rx_header_parsed_len_ - rx_buf_len_; + ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read); + if (received == -1) { + if (errno == EWOULDBLOCK || errno == EAGAIN) { + return ErrorCode::WOULD_BLOCK; + } + enter_error_(); + ESP_LOGV(TAG, "Socket read failed with errno %d", errno); + return ErrorCode::SOCKET_ERROR; + } else if (received == 0) { + enter_error_(); + ESP_LOGD(TAG, "Connection closed"); + return ErrorCode::CONNECTION_CLOSED; + } + rx_buf_len_ += received; + if ((size_t) received != to_read) { + // not all read + return ErrorCode::WOULD_BLOCK; + } + } + // body reading done + + ec = handle_packet_(rx_header_parsed_type_, rx_header_parsed_flags_, rx_buf_.data(), rx_buf_.size()); + // prepare for next packet + rx_header_parsed_ = false; + return ec; +} + +ErrorCode MQTTConnection::handle_packet_(uint8_t packet_type, uint8_t flags, const uint8_t *data, size_t len) { + util::Parser parser(data, len); + ErrorCode ec; + + switch (static_cast(packet_type)) { + case PacketType::CONNACK: { + if (state_ != State::WAIT_CONNACK) { + enter_error_(); + ESP_LOGV(TAG, "Bad state for connack %d", (int) state_); + return ErrorCode::BAD_STATE; + } + ConnackPacket packet{}; + ec = packet.decode(flags, parser); + if (ec != ErrorCode::OK) { + enter_error_(); + ESP_LOGV(TAG, "Error decoding connack packet %d", (int) ec); + return ec; + } + + if (packet.connect_return_code != ConnectReturnCode::ACCEPTED) { + const char *reason; + switch (packet.connect_return_code) { + case ConnectReturnCode::UNACCEPTABLE_PROTOCOL_VERSION: + reason = "unacceptable protocol version"; + break; + case ConnectReturnCode::IDENTIFIER_REJECTED: + reason = "identifier rejected"; + break; + case ConnectReturnCode::SERVER_UNAVAILABLE: + reason = "server unavailable"; + break; + case ConnectReturnCode::BAD_USER_NAME_OR_PASSWORD: + reason = "bad user name or password"; + break; + case ConnectReturnCode::NOT_AUTHORIZED: + reason = "not authorized"; + break; + default: + reason = "unknown"; + break; + } + enter_error_(); + ESP_LOGW(TAG, "Connect failed: %s", reason); + return ErrorCode::PROTOCOL_ERROR; + } + + ESP_LOGD(TAG, "Connected!"); + state_ = State::CONNECTED; + return ErrorCode::OK; + } + case PacketType::PUBLISH: { + if (state_ != State::CONNECTED) { + enter_error_(); + ESP_LOGV(TAG, "Bad state for publish %d", (int) state_); + return ErrorCode::BAD_STATE; + } + + return ErrorCode::OK; + } + + case PacketType::PUBACK: + case PacketType::PUBREC: + case PacketType::PUBREL: + case PacketType::PUBCOMP: + case PacketType::SUBACK: + case PacketType::UNSUBACK: { + // TODO + ESP_LOGD(TAG, "Received packet with type %u", packet_type); + return ErrorCode::OK; + } + + case PacketType::PINGRESP: { + // TODO rx timer + return ErrorCode::OK; + } + + case PacketType::CONNECT: + case PacketType::DISCONNECT: + case PacketType::SUBSCRIBE: + case PacketType::UNSUBSCRIBE: + case PacketType::PINGREQ: + default: { + enter_error_(); + ESP_LOGW(TAG, "Received unknown packet type %u", packet_type); + return ErrorCode::UNEXPECTED; + } + + } +} + +ErrorCode MQTTConnection::publish(std::string topic, std::vector message, bool retain, QOSLevel qos) { + PublishPacket packet{}; + packet.topic = std::move(topic); + packet.message = std::move(message); + packet.retain = retain; + packet.qos = qos; + if (packet.qos != QOSLevel::QOS0) { + packet.packet_identifier = session_->create_packet_id(); + } + packet.dup = false; + + std::vector packet_enc; + ErrorCode ec = packet.encode(packet_enc); + if (ec != ErrorCode::OK) + return ec; + + ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), qos != QOSLevel::QOS0); + if (ec != ErrorCode::OK && ec != ErrorCode::WOULD_BLOCK) { + enter_error_(); + ESP_LOGV(TAG, "publish write failed"); + return ec; + } + + return ec; +} +ErrorCode MQTTConnection::subscribe(std::vector subscriptions) { + SubscribePacket packet{}; + packet.subscriptions = std::move(subscriptions); + packet.packet_identifier = session_->create_packet_id(); + + std::vector packet_enc; + ErrorCode ec = packet.encode(packet_enc); + if (ec != ErrorCode::OK) + return ec; + + ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), true); + if (ec != ErrorCode::OK) { + enter_error_(); + ESP_LOGV(TAG, "subscribe write failed"); + return ec; + } + return ec; +} +ErrorCode MQTTConnection::unsubscribe(std::vector topic_filters) { + UnsubscribePacket packet{}; + packet.topic_filters = std::move(topic_filters); + packet.packet_identifier = session_->create_packet_id(); + + std::vector packet_enc; + ErrorCode ec = packet.encode(packet_enc); + if (ec != ErrorCode::OK) + return ec; + + ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), true); + if (ec != ErrorCode::OK) { + enter_error_(); + ESP_LOGV(TAG, "unsubscribe write failed"); + return ec; + } + return ec; +} + +} // namespace mqtt +} // namespace esphome + +#endif // USE_MQTT diff --git a/esphome/components/mqtt/mqtt_backend.h b/esphome/components/mqtt/mqtt_backend.h new file mode 100644 index 0000000000..4f49a37e4d --- /dev/null +++ b/esphome/components/mqtt/mqtt_backend.h @@ -0,0 +1,143 @@ +#pragma once + +#include "esphome/core/defines.h" + +#ifdef USE_MQTT + +#include "packets.h" +#include "esphome/components/socket/socket.h" +#include "esphome/components/socket/getaddrinfo.h" +#include +#include + +namespace esphome { +namespace mqtt { + +namespace util { + +class ConnectionEstablisher { + public: + ErrorCode init(const std::string &host, uint16_t port, uint32_t timeout); + ErrorCode loop(); + // Should only be called when loop() returns OK, is guaranteed to succeed + std::unique_ptr extract_socket(); + + protected: + void enter_error_(); + + std::unique_ptr socket_; + std::unique_ptr getaddrinfo_; + uint32_t timeout_; + uint32_t start_; + + enum class State { + UNINIT = 0, + RESOLVING = 1, + CONNECTING = 2, + CONNECTED = 3, + FINISHED = 4, + ERROR = 5, + } state_ = State::UNINIT; +}; + +class BufferedWriter { + public: + ErrorCode write(const std::unique_ptr &sock, const uint8_t *data, size_t len, + bool do_buffer); + ErrorCode try_drain(const std::unique_ptr &sock); + void stop() { + tx_buf_ = {}; + } + + protected: + std::vector tx_buf_; +}; + +} // namespace util + +struct ConnectParams { + std::string host; + uint16_t port; + + std::string client_id; + optional username; + optional> password; + std::string will_topic; + std::vector will_message; + QOSLevel will_qos; + bool will_retain; + uint16_t keep_alive; +}; + +class MQTTSession { + public: + bool get_has_session() const { return has_session_; } + void set_has_session(bool has_session) { has_session_ = has_session; } + void clean_session() { + packet_id_counter_ = 0; + used_packet_ids_.clear(); + } + uint16_t create_packet_id() { + while (true) { + packet_id_counter_++; + if (packet_id_counter_ == 0 || used_packet_ids_.count(packet_id_counter_) > 0) { + continue; + } + used_packet_ids_.insert(packet_id_counter_); + return packet_id_counter_; + } + } + void return_packet_id(uint16_t packet_id) { + used_packet_ids_.erase(packet_id); + } + protected: + bool has_session_ = false; + uint32_t packet_id_counter_ = 0; + std::set used_packet_ids_; +}; + +class MQTTConnection { + public: + ErrorCode init(ConnectParams *params, MQTTSession *session); + ErrorCode loop(); + bool is_connected() { return state_ == State::CONNECTED; } + + ErrorCode publish(std::string topic, std::vector message, bool retain, QOSLevel qos); + ErrorCode subscribe(std::vector subscriptions); + ErrorCode unsubscribe(std::vector topic_filters); + + protected: + void enter_error_(); + ErrorCode read_packet_(); + ErrorCode handle_packet_(uint8_t packet_type, uint8_t flags, const uint8_t *data, size_t len); + + ConnectParams *params_; + MQTTSession *session_; + std::unique_ptr connection_establisher_; + std::unique_ptr socket_; + + std::vector rx_header_buf_; + bool rx_header_parsed_ = false; + uint8_t rx_header_parsed_type_ = 0; + uint8_t rx_header_parsed_flags_ = 0; + size_t rx_header_parsed_len_ = 0; + + std::vector rx_buf_; + size_t rx_buf_len_ = 0; + + util::BufferedWriter writer_; + + enum class State { + UNINIT = 0, + CONNECTING = 1, + WAIT_CONNACK = 2, + CONNECTED = 3, + DISCONNECTED = 4, + ERROR = 5, + } state_ = State::UNINIT; +}; + +} // namespace mqtt +} // namespace esphome + +#endif // USE_MQTT diff --git a/esphome/components/mqtt/mqtt_client.cpp b/esphome/components/mqtt/mqtt_client.cpp index 67063d4c72..0e99ca66bd 100644 --- a/esphome/components/mqtt/mqtt_client.cpp +++ b/esphome/components/mqtt/mqtt_client.cpp @@ -28,7 +28,7 @@ MQTTClientComponent::MQTTClientComponent() { // Connection void MQTTClientComponent::setup() { ESP_LOGCONFIG(TAG, "Setting up MQTT..."); - this->mqtt_client_.onMessage([this](char const *topic, char *payload, AsyncMqttClientMessageProperties properties, + /*this->mqtt_client_.onMessage([this](char const *topic, char *payload, AsyncMqttClientMessageProperties properties, size_t len, size_t index, size_t total) { if (index == 0) this->payload_buffer_.reserve(total); @@ -45,7 +45,7 @@ void MQTTClientComponent::setup() { this->mqtt_client_.onDisconnect([this](AsyncMqttClientDisconnectReason reason) { this->state_ = MQTT_CLIENT_DISCONNECTED; this->disconnect_reason_ = reason; - }); + });*/ #ifdef USE_LOGGER if (this->is_log_message_enabled() && logger::global_logger != nullptr) { logger::global_logger->add_on_log_callback([this](int level, const char *tag, const char *message) { @@ -58,12 +58,11 @@ void MQTTClientComponent::setup() { #endif this->last_connected_ = millis(); - this->start_dnslookup_(); + this->start_connect_(); } void MQTTClientComponent::dump_config() { ESP_LOGCONFIG(TAG, "MQTT:"); - ESP_LOGCONFIG(TAG, " Server Address: %s:%u (%s)", this->credentials_.address.c_str(), this->credentials_.port, - this->ip_.str().c_str()); + ESP_LOGCONFIG(TAG, " Server Address: %s:%u", this->credentials_.address.c_str(), this->credentials_.port); ESP_LOGCONFIG(TAG, " Username: " LOG_SECRET("'%s'"), this->credentials_.username.c_str()); ESP_LOGCONFIG(TAG, " Client ID: " LOG_SECRET("'%s'"), this->credentials_.client_id.c_str()); if (!this->discovery_info_.prefix.empty()) { @@ -80,131 +79,68 @@ void MQTTClientComponent::dump_config() { } bool MQTTClientComponent::can_proceed() { return this->is_connected(); } -void MQTTClientComponent::start_dnslookup_() { + +void MQTTClientComponent::start_connect_() { + if (!network::is_connected()) + return; + for (auto &subscription : this->subscriptions_) { subscription.subscribed = false; subscription.resubscribe_timeout = 0; } this->status_set_warning(); - this->dns_resolve_error_ = false; - this->dns_resolved_ = false; - ip_addr_t addr; -#ifdef USE_ESP32 - err_t err = dns_gethostbyname_addrtype(this->credentials_.address.c_str(), &addr, - MQTTClientComponent::dns_found_callback, this, LWIP_DNS_ADDRTYPE_IPV4); -#endif -#ifdef USE_ESP8266 - err_t err = dns_gethostbyname(this->credentials_.address.c_str(), &addr, - esphome::mqtt::MQTTClientComponent::dns_found_callback, this); -#endif - switch (err) { - case ERR_OK: { - // Got IP immediately - this->dns_resolved_ = true; -#ifdef USE_ESP32 - this->ip_ = addr.u_addr.ip4.addr; -#endif -#ifdef USE_ESP8266 - this->ip_ = addr.addr; -#endif - this->start_connect_(); - return; - } - case ERR_INPROGRESS: { - // wait for callback - ESP_LOGD(TAG, "Resolving MQTT broker IP address..."); - break; - } - default: - case ERR_ARG: { - // error -#if defined(USE_ESP8266) - ESP_LOGW(TAG, "Error resolving MQTT broker IP address: %ld", err); -#else - ESP_LOGW(TAG, "Error resolving MQTT broker IP address: %d", err); -#endif - break; - } - } - - this->state_ = MQTT_CLIENT_RESOLVING_ADDRESS; - this->connect_begin_ = millis(); -} -void MQTTClientComponent::check_dnslookup_() { - if (!this->dns_resolved_ && millis() - this->connect_begin_ > 20000) { - this->dns_resolve_error_ = true; - } - - if (this->dns_resolve_error_) { - ESP_LOGW(TAG, "Couldn't resolve IP address for '%s'!", this->credentials_.address.c_str()); - this->state_ = MQTT_CLIENT_DISCONNECTED; - return; - } - - if (!this->dns_resolved_) { - return; - } - - ESP_LOGD(TAG, "Resolved broker IP address to %s", this->ip_.str().c_str()); - this->start_connect_(); -} -#if defined(USE_ESP8266) && LWIP_VERSION_MAJOR == 1 -void MQTTClientComponent::dns_found_callback(const char *name, ip_addr_t *ipaddr, void *callback_arg) { -#else -void MQTTClientComponent::dns_found_callback(const char *name, const ip_addr_t *ipaddr, void *callback_arg) { -#endif - auto *a_this = (MQTTClientComponent *) callback_arg; - if (ipaddr == nullptr) { - a_this->dns_resolve_error_ = true; - } else { -#ifdef USE_ESP32 - a_this->ip_ = ipaddr->u_addr.ip4.addr; -#endif -#ifdef USE_ESP8266 - a_this->ip_ = ipaddr->addr; -#endif - a_this->dns_resolved_ = true; - } -} - -void MQTTClientComponent::start_connect_() { - if (!network::is_connected()) - return; ESP_LOGI(TAG, "Connecting to MQTT..."); - // Force disconnect first - this->mqtt_client_.disconnect(true); - this->mqtt_client_.setClientId(this->credentials_.client_id.c_str()); - const char *username = nullptr; - if (!this->credentials_.username.empty()) - username = this->credentials_.username.c_str(); - const char *password = nullptr; - if (!this->credentials_.password.empty()) - password = this->credentials_.password.c_str(); + conn_params_.host = credentials_.address; + conn_params_.port = credentials_.port; + conn_params_.client_id = credentials_.client_id; - this->mqtt_client_.setCredentials(username, password); - - this->mqtt_client_.setServer((uint32_t) this->ip_, this->credentials_.port); - if (!this->last_will_.topic.empty()) { - this->mqtt_client_.setWill(this->last_will_.topic.c_str(), this->last_will_.qos, this->last_will_.retain, - this->last_will_.payload.c_str(), this->last_will_.payload.length()); + if (!credentials_.username.empty()) + conn_params_.username = credentials_.username; + else + conn_params_.username.reset(); + if (!credentials_.password.empty()) { + std::vector pwd{credentials_.password.begin(), credentials_.password.end()}; + conn_params_.password = pwd; + } else { + conn_params_.password.reset(); + } + + if (!last_will_.topic.empty()) { + conn_params_.will_topic = last_will_.topic; + std::vector msg{last_will_.payload.begin(), last_will_.payload.end()}; + conn_params_.will_message = msg; + conn_params_.will_retain = last_will_.retain; + conn_params_.will_qos = static_cast(last_will_.qos); + } else { + conn_params_.will_topic = ""; + conn_params_.will_message.clear(); + conn_params_.will_retain = false; + conn_params_.will_qos = QOSLevel::QOS0; + } + + conn_ = make_unique(); + ErrorCode ec = conn_->init(&conn_params_, &sess_); + if (ec != ErrorCode::OK) { + ESP_LOGW(TAG, "connection init failed: %d", (int) ec); + return; } - this->mqtt_client_.connect(); this->state_ = MQTT_CLIENT_CONNECTING; this->connect_begin_ = millis(); } bool MQTTClientComponent::is_connected() { - return this->state_ == MQTT_CLIENT_CONNECTED && this->mqtt_client_.connected(); + return this->state_ == MQTT_CLIENT_CONNECTED && this->conn_->is_connected(); } void MQTTClientComponent::check_connected() { - if (!this->mqtt_client_.connected()) { - if (millis() - this->connect_begin_ > 60000) { - this->state_ = MQTT_CLIENT_DISCONNECTED; - this->start_dnslookup_(); + if (conn_ && !conn_->is_connected()) { + ErrorCode ec = conn_->loop(); + if (ec != ErrorCode::OK) { + ESP_LOGW(TAG, "check connected loop failed: %d", (int) ec); + state_ = MQTT_CLIENT_DISCONNECTED; } return; } @@ -223,64 +159,25 @@ void MQTTClientComponent::check_connected() { } void MQTTClientComponent::loop() { - if (this->disconnect_reason_.has_value()) { - const LogString *reason_s; - switch (*this->disconnect_reason_) { - case AsyncMqttClientDisconnectReason::TCP_DISCONNECTED: - reason_s = LOG_STR("TCP disconnected"); - break; - case AsyncMqttClientDisconnectReason::MQTT_UNACCEPTABLE_PROTOCOL_VERSION: - reason_s = LOG_STR("Unacceptable Protocol Version"); - break; - case AsyncMqttClientDisconnectReason::MQTT_IDENTIFIER_REJECTED: - reason_s = LOG_STR("Identifier Rejected"); - break; - case AsyncMqttClientDisconnectReason::MQTT_SERVER_UNAVAILABLE: - reason_s = LOG_STR("Server Unavailable"); - break; - case AsyncMqttClientDisconnectReason::MQTT_MALFORMED_CREDENTIALS: - reason_s = LOG_STR("Malformed Credentials"); - break; - case AsyncMqttClientDisconnectReason::MQTT_NOT_AUTHORIZED: - reason_s = LOG_STR("Not Authorized"); - break; - case AsyncMqttClientDisconnectReason::ESP8266_NOT_ENOUGH_SPACE: - reason_s = LOG_STR("Not Enough Space"); - break; - case AsyncMqttClientDisconnectReason::TLS_BAD_FINGERPRINT: - reason_s = LOG_STR("TLS Bad Fingerprint"); - break; - default: - reason_s = LOG_STR("Unknown"); - break; - } - if (!network::is_connected()) { - reason_s = LOG_STR("WiFi disconnected"); - } - ESP_LOGW(TAG, "MQTT Disconnected: %s.", LOG_STR_ARG(reason_s)); - this->disconnect_reason_.reset(); - } - const uint32_t now = millis(); switch (this->state_) { case MQTT_CLIENT_DISCONNECTED: - if (now - this->connect_begin_ > 5000) { - this->start_dnslookup_(); - } - break; - case MQTT_CLIENT_RESOLVING_ADDRESS: - this->check_dnslookup_(); break; case MQTT_CLIENT_CONNECTING: this->check_connected(); break; case MQTT_CLIENT_CONNECTED: - if (!this->mqtt_client_.connected()) { + if (!this->conn_->is_connected()) { this->state_ = MQTT_CLIENT_DISCONNECTED; ESP_LOGW(TAG, "Lost MQTT Client connection!"); - this->start_dnslookup_(); + this->start_connect_(); } else { + ErrorCode ec = conn_->loop(); + if (ec != ErrorCode::OK) { + ESP_LOGW(TAG, "loop loop failed"); + } + if (!this->birth_message_.topic.empty() && !this->sent_birth_message_) { this->sent_birth_message_ = this->publish(this->birth_message_); } @@ -303,17 +200,18 @@ bool MQTTClientComponent::subscribe_(const char *topic, uint8_t qos) { if (!this->is_connected()) return false; - uint16_t ret = this->mqtt_client_.subscribe(topic, qos); - yield(); + Subscription sub{}; + sub.topic_filter = topic; + sub.requested_qos = static_cast(qos); + ErrorCode ec = this->conn_->subscribe({sub}); - if (ret != 0) { - ESP_LOGV(TAG, "subscribe(topic='%s')", topic); - } else { - delay(5); + if (ec != ErrorCode::OK) { ESP_LOGV(TAG, "Subscribe failed for topic='%s'. Will retry later.", topic); this->status_momentary_warning("subscribe", 1000); } - return ret != 0; + + ESP_LOGV(TAG, "subscribe(topic='%s')", topic); + return ec == ErrorCode::OK; } void MQTTClientComponent::resubscribe_subscription_(MQTTSubscription *sub) { if (sub->subscribed) @@ -361,15 +259,12 @@ void MQTTClientComponent::subscribe_json(const std::string &topic, const mqtt_js } void MQTTClientComponent::unsubscribe(const std::string &topic) { - uint16_t ret = this->mqtt_client_.unsubscribe(topic.c_str()); - yield(); - if (ret != 0) { - ESP_LOGV(TAG, "unsubscribe(topic='%s')", topic.c_str()); - } else { - delay(5); + ErrorCode ec = this->conn_->unsubscribe({topic}); + if (ec != ErrorCode::OK) { ESP_LOGV(TAG, "Unsubscribe failed for topic='%s'.", topic.c_str()); this->status_momentary_warning("unsubscribe", 1000); } + ESP_LOGV(TAG, "unsubscribe(topic='%s')", topic.c_str()); auto it = subscriptions_.begin(); while (it != subscriptions_.end()) { @@ -393,24 +288,22 @@ bool MQTTClientComponent::publish(const std::string &topic, const char *payload, return false; } bool logging_topic = topic == this->log_message_.topic; - uint16_t ret = this->mqtt_client_.publish(topic.c_str(), qos, retain, payload, payload_length); - delay(0); - if (ret == 0 && !logging_topic && this->is_connected()) { - delay(0); - ret = this->mqtt_client_.publish(topic.c_str(), qos, retain, payload, payload_length); - delay(0); + std::vector msg; + for (size_t i = 0; i < payload_length; i++) { + msg.push_back(static_cast(payload[i])); } + ErrorCode ec = this->conn_->publish(topic, std::move(msg), retain, static_cast(qos)); if (!logging_topic) { - if (ret != 0) { - ESP_LOGV(TAG, "Publish(topic='%s' payload='%s' retain=%d)", topic.c_str(), payload, retain); - } else { + if (ec != ErrorCode::OK) { ESP_LOGV(TAG, "Publish failed for topic='%s' (len=%u). will retry later..", topic.c_str(), payload_length); // NOLINT this->status_momentary_warning("publish", 1000); + } else { + ESP_LOGV(TAG, "Publish(topic='%s' payload='%s' retain=%d)", topic.c_str(), payload, retain); } } - return ret != 0; + return ec == ErrorCode::OK; } bool MQTTClientComponent::publish(const MQTTMessage &message) { @@ -500,7 +393,7 @@ bool MQTTClientComponent::is_log_message_enabled() const { return !this->log_mes void MQTTClientComponent::set_reboot_timeout(uint32_t reboot_timeout) { this->reboot_timeout_ = reboot_timeout; } void MQTTClientComponent::register_mqtt_component(MQTTComponent *component) { this->children_.push_back(component); } void MQTTClientComponent::set_log_level(int level) { this->log_level_ = level; } -void MQTTClientComponent::set_keep_alive(uint16_t keep_alive_s) { this->mqtt_client_.setKeepAlive(keep_alive_s); } +void MQTTClientComponent::set_keep_alive(uint16_t keep_alive_s) { conn_params_.keep_alive = keep_alive_s; } void MQTTClientComponent::set_log_message_template(MQTTMessage &&message) { this->log_message_ = std::move(message); } const MQTTDiscoveryInfo &MQTTClientComponent::get_discovery_info() const { return this->discovery_info_; } void MQTTClientComponent::set_topic_prefix(std::string topic_prefix) { this->topic_prefix_ = std::move(topic_prefix); } @@ -556,7 +449,7 @@ void MQTTClientComponent::on_shutdown() { this->publish(this->shutdown_message_); yield(); } - this->mqtt_client_.disconnect(true); + // this->mqtt_client_.disconnect(true); } #if ASYNC_TCP_SSL_ENABLED diff --git a/esphome/components/mqtt/mqtt_client.h b/esphome/components/mqtt/mqtt_client.h index a6a7025c6f..3a74fbac54 100644 --- a/esphome/components/mqtt/mqtt_client.h +++ b/esphome/components/mqtt/mqtt_client.h @@ -9,8 +9,7 @@ #include "esphome/core/log.h" #include "esphome/components/json/json_util.h" #include "esphome/components/network/ip_address.h" -#include -#include "lwip/ip_addr.h" +#include "mqtt_backend.h" namespace esphome { namespace mqtt { @@ -74,7 +73,6 @@ struct MQTTDiscoveryInfo { enum MQTTClientState { MQTT_CLIENT_DISCONNECTED = 0, - MQTT_CLIENT_RESOLVING_ADDRESS, MQTT_CLIENT_CONNECTING, MQTT_CLIENT_CONNECTED, }; @@ -116,22 +114,6 @@ class MQTTClientComponent : public Component { void disable_discovery(); bool is_discovery_enabled() const; -#if ASYNC_TCP_SSL_ENABLED - /** Add a SSL fingerprint to use for TCP SSL connections to the MQTT broker. - * - * To use this feature you first have to globally enable the `ASYNC_TCP_SSL_ENABLED` define flag. - * This function can be called multiple times and any certificate that matches any of the provided fingerprints - * will match. Calling this method will also automatically disable all non-ssl connections. - * - * @warning This is *not* secure and *not* how SSL is usually done. You'll have to add - * a separate fingerprint for every certificate you use. Additionally, the hashing - * algorithm used here due to the constraints of the MCU, SHA1, is known to be insecure. - * - * @param fingerprint The SSL fingerprint as a 20 value long std::array. - */ - void add_ssl_fingerprint(const std::array &fingerprint); -#endif - const Availability &get_availability(); /** Set the topic prefix that will be prepended to all topics together with "/". This will, in most cases, @@ -237,13 +219,6 @@ class MQTTClientComponent : public Component { protected: /// Reconnect to the MQTT broker if not already connected. void start_connect_(); - void start_dnslookup_(); - void check_dnslookup_(); -#if defined(USE_ESP8266) && LWIP_VERSION_MAJOR == 1 - static void dns_found_callback(const char *name, ip_addr_t *ipaddr, void *callback_arg); -#else - static void dns_found_callback(const char *name, const ip_addr_t *ipaddr, void *callback_arg); -#endif /// Re-calculate the availability property. void recalculate_availability_(); @@ -272,20 +247,18 @@ class MQTTClientComponent : public Component { }; std::string topic_prefix_{}; MQTTMessage log_message_; - std::string payload_buffer_; int log_level_{ESPHOME_LOG_LEVEL}; std::vector subscriptions_; - AsyncMqttClient mqtt_client_; MQTTClientState state_{MQTT_CLIENT_DISCONNECTED}; - network::IPAddress ip_; - bool dns_resolved_{false}; - bool dns_resolve_error_{false}; std::vector children_; uint32_t reboot_timeout_{300000}; uint32_t connect_begin_; uint32_t last_connected_{0}; - optional disconnect_reason_{}; + + std::unique_ptr conn_; + ConnectParams conn_params_{}; + MQTTSession sess_{}; }; extern MQTTClientComponent *global_mqtt_client; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/esphome/components/mqtt/packets.cpp b/esphome/components/mqtt/packets.cpp new file mode 100644 index 0000000000..66791ecf38 --- /dev/null +++ b/esphome/components/mqtt/packets.cpp @@ -0,0 +1,135 @@ +#include "packets.h" + +#ifdef USE_MQTT + +namespace esphome { +namespace mqtt { + +namespace util { + +ErrorCode encode_uint16(std::vector &target, uint16_t value) { + target.push_back((value >> 8) & 0xFF); + target.push_back((value >> 0) & 0xFF); + return ErrorCode::OK; +} +ErrorCode decode_uint16(Parser *parser, uint16_t *value) { + if (parser->size_left() < 2) + return ErrorCode::MALFORMED_PACKET; + *value = 0; + *value |= static_cast(parser->consume()) << 8; + *value |= static_cast(parser->consume()) << 0; + return ErrorCode::OK; +} +ErrorCode encode_bytes(std::vector &target, const std::vector &value) { + if (value.size() > 65535) + return ErrorCode::VALUE_TOO_LONG; + encode_uint16(target, value.size()); + target.insert(target.end(), value.begin(), value.end()); + return ErrorCode::OK; +} +ErrorCode decode_bytes(Parser *parser, std::vector *value) { + uint16_t len; + ErrorCode ec = decode_uint16(parser, &len); + if (ec != ErrorCode::OK) + return ec; + if (len > parser->size_left()) + return ErrorCode::MALFORMED_PACKET; + value->clear(); + value->reserve(len); + for (size_t i = 0; i < len; i++) + value->push_back(parser->consume()); + return ErrorCode::OK; +} +ErrorCode encode_utf8(std::vector &target, const std::string &value) { + if (value.size() > 65535) + return ErrorCode::VALUE_TOO_LONG; + encode_uint16(target, value.size()); + for (char c : value) + target.push_back(static_cast(c)); + return ErrorCode::OK; +} +ErrorCode decode_utf8(Parser *parser, std::string *value) { + uint16_t len; + ErrorCode ec = decode_uint16(parser, &len); + if (ec != ErrorCode::OK) + return ec; + if (len > parser->size_left()) + return ErrorCode::MALFORMED_PACKET; + value->clear(); + value->reserve(len); + for (size_t i = 0; i < len; i++) + value->push_back(static_cast(parser->consume())); + return ErrorCode::OK; +} +ErrorCode encode_varint(std::vector &target, size_t value) { + do { + uint8_t encbyte = value % 0x80; + value >>= 7; + if (value > 0) + encbyte |= 0x80; + target.push_back(encbyte); + } while (value > 0); + return ErrorCode::OK; +} + +ErrorCode encode_fixed_header( + std::vector &target, + PacketType packet_type, + uint8_t flags, + size_t remaining_length +) { + uint8_t head = 0; + head |= static_cast(packet_type) << 4; + head |= flags; + target.push_back(head); + return encode_varint(target, remaining_length); +} + +ErrorCode encode_packet( + std::vector &target, + PacketType packet_type, + uint8_t flags +) { + return encode_fixed_header(target, packet_type, flags, 0); +} + +ErrorCode encode_packet( + std::vector &target, + PacketType packet_type, + uint8_t flags, + const std::vector &variable_header +) { + ErrorCode ec = encode_fixed_header( + target, packet_type, flags, + variable_header.size() + ); + if (ec != ErrorCode::OK) + return ec; + target.insert(target.end(), variable_header.begin(), variable_header.end()); + return ErrorCode::OK; +} + +ErrorCode encode_packet( + std::vector &target, + PacketType packet_type, + uint8_t flags, + const std::vector &variable_header, + const std::vector &payload +) { + ErrorCode ec = encode_fixed_header( + target, packet_type, flags, + variable_header.size() + payload.size() + ); + if (ec != ErrorCode::OK) + return ec; + target.insert(target.end(), variable_header.begin(), variable_header.end()); + target.insert(target.end(), payload.begin(), payload.end()); + return ErrorCode::OK; +} + +} // namespace util + +} // namespace mqtt +} // namespace esphome + +#endif // USE_MQTT diff --git a/esphome/components/mqtt/packets.h b/esphome/components/mqtt/packets.h new file mode 100644 index 0000000000..f675cf6d01 --- /dev/null +++ b/esphome/components/mqtt/packets.h @@ -0,0 +1,653 @@ +#pragma once + +#include "esphome/core/defines.h" + +#ifdef USE_MQTT + +#include +#include +#include +#include "esphome/core/optional.h" + +namespace esphome { +namespace mqtt { + +enum class ErrorCode { + OK = 0, + VALUE_TOO_LONG = 1, + BAD_FLAGS = 2, + MALFORMED_PACKET = 3, + BAD_STATE = 4, + RESOLVE_ERROR = 5, + SOCKET_ERROR = 6, + TIMEOUT = 7, + IN_PROGRESS = 8, + WOULD_BLOCK = 9, + CONNECTION_CLOSED = 10, + UNEXPECTED = 11, + PROTOCOL_ERROR = 12, +}; + +enum class PacketType : uint8_t { + CONNECT = 1, + CONNACK = 2, + PUBLISH = 3, + PUBACK = 4, + PUBREC = 5, + PUBREL = 6, + PUBCOMP = 7, + SUBSCRIBE = 8, + SUBACK = 9, + UNSUBSCRIBE = 10, + UNSUBACK = 11, + PINGREQ = 12, + PINGRESP = 13, + DISCONNECT = 14, +}; + +namespace util { +class Parser { + public: + Parser(const uint8_t *data, size_t len) : data_(data), len_(len) {} + + size_t size_left() const { + return len_ - at_; + } + + uint8_t consume() { + return data_[at_++]; + } + + void consume(size_t amount) { + at_ += amount; + } + + private: + const uint8_t *data_; + size_t len_; + size_t at_ = 0; +}; + +ErrorCode encode_uint16(std::vector &target, uint16_t value); +ErrorCode decode_uint16(Parser *parser, uint16_t *value); +ErrorCode encode_bytes(std::vector &target, const std::vector &value); +ErrorCode decode_bytes(Parser *parser, std::vector *value); +ErrorCode encode_utf8(std::vector &target, const std::string &value); +ErrorCode decode_utf8(Parser *parser, std::string *value); +ErrorCode encode_varint(std::vector &target, size_t value); +ErrorCode encode_fixed_header( + std::vector &target, + PacketType packet_type, + uint8_t flags, + size_t remaining_length +); +ErrorCode encode_packet( + std::vector &target, + PacketType packet_type, + uint8_t flags +); +ErrorCode encode_packet( + std::vector &target, + PacketType packet_type, + uint8_t flags, + const std::vector &variable_header +); +ErrorCode encode_packet( + std::vector &target, + PacketType packet_type, + uint8_t flags, + const std::vector &variable_header, + const std::vector &payload +); + +} // namespace util + +class MQTTPacket { + public: + virtual ErrorCode encode(std::vector &target) const = 0; + virtual ErrorCode decode(uint8_t flags, util::Parser parser) = 0; +}; + +enum class QOSLevel : uint8_t { + QOS0 = 0, + QOS1 = 1, + QOS2 = 2, +}; + +class ConnectPacket : public MQTTPacket { + public: + // 3.1 + std::string client_id; + optional username; + optional> password; + std::string will_topic; + std::vector will_message; + QOSLevel will_qos = QOSLevel::QOS0; + bool will_retain = false; + bool clean_session = true; + uint8_t protocol_level = 4; + uint16_t keep_alive; + + ErrorCode encode(std::vector &target) const final { + uint8_t connect_flags = 0; + if (username.has_value()) + connect_flags |= 0x80; + if (password.has_value()) + connect_flags |= 0x40; + if (will_retain) + connect_flags |= 0x20; + connect_flags |= static_cast(will_qos) << 3; + if (!will_topic.empty()) + connect_flags |= 0x04; + if (clean_session) + connect_flags |= 0x02; + std::vector variable_header; + variable_header.push_back(0x00); + variable_header.push_back(0x04); + variable_header.push_back('M'); + variable_header.push_back('Q'); + variable_header.push_back('T'); + variable_header.push_back('T'); + variable_header.push_back(protocol_level); + variable_header.push_back(connect_flags); + ErrorCode ec = util::encode_uint16(variable_header, keep_alive); + if (ec != ErrorCode::OK) + return ec; + + std::vector payload; + ec = util::encode_utf8(payload, client_id); + if (ec != ErrorCode::OK) + return ec; + if (!will_topic.empty()) { + ec = util::encode_utf8(payload, will_topic); + if (ec != ErrorCode::OK) + return ec; + ec = util::encode_bytes(payload, will_message); + if (ec != ErrorCode::OK) + return ec; + } + if (username.has_value()) { + ec = util::encode_utf8(payload, *username); + if (ec != ErrorCode::OK) + return ec; + } + if (password.has_value()) { + ec = util::encode_bytes(payload, *password); + if (ec != ErrorCode::OK) + return ec; + } + + return util::encode_packet( + target, PacketType::CONNECT, 0, + variable_header, payload + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 0) + return ErrorCode::BAD_FLAGS; + if ( + parser.size_left() < 10 + || parser.consume() != '\x00' + || parser.consume() != '\x04' + || parser.consume() != 'M' + || parser.consume() != 'Q' + || parser.consume() != 'T' + || parser.consume() != 'T' + ) + return ErrorCode::MALFORMED_PACKET; + protocol_level = parser.consume(); + uint8_t connect_flags = parser.consume(); + bool username_flag = connect_flags & 0x80; + bool password_flag = connect_flags & 0x40; + will_retain = connect_flags & 0x20; + will_qos = static_cast((connect_flags >> 3) & 3); + bool will_flag = connect_flags & 0x04; + clean_session = connect_flags & 0x02; + if ((flags & 1) != 0) + return ErrorCode::MALFORMED_PACKET; + ErrorCode ec = util::decode_uint16(&parser, &keep_alive); + if (ec != ErrorCode::OK) + return ec; + + ec = util::decode_utf8(&parser, &client_id); + if (ec != ErrorCode::OK) + return ec; + + will_topic.clear(); + will_message.clear(); + if (will_flag) { + ec = util::decode_utf8(&parser, &will_topic); + if (ec != ErrorCode::OK) + return ec; + ec = util::decode_bytes(&parser, &will_message); + if (ec != ErrorCode::OK) + return ec; + } + + username.reset(); + if (username_flag) { + username = {""}; + ec = util::decode_utf8(&parser, &(*username)); + if (ec != ErrorCode::OK) + return ec; + } + password.reset(); + if (password_flag) { + password = std::vector{}; + ec = util::decode_bytes(&parser, &(*password)); + if (ec != ErrorCode::OK) + return ec; + } + + return ErrorCode::OK; + } +}; + +enum class ConnectReturnCode : uint8_t { + ACCEPTED = 0, + UNACCEPTABLE_PROTOCOL_VERSION = 1, + IDENTIFIER_REJECTED = 2, + SERVER_UNAVAILABLE = 3, + BAD_USER_NAME_OR_PASSWORD = 4, + NOT_AUTHORIZED = 5, +}; + +class ConnackPacket : public MQTTPacket { + public: + // 3.2 + bool session_present; + ConnectReturnCode connect_return_code; + + ErrorCode encode(std::vector &target) const final { + std::vector variable_header; + variable_header.push_back(static_cast(session_present)); + variable_header.push_back(static_cast(connect_return_code)); + return util::encode_packet( + target, PacketType::CONNACK, 0, + variable_header + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 0) + return ErrorCode::BAD_FLAGS; + if (parser.size_left() != 2) + return ErrorCode::MALFORMED_PACKET; + session_present = parser.consume() & 1; + connect_return_code = static_cast(parser.consume()); + return ErrorCode::OK; + } +}; + +class PublishPacket : public MQTTPacket { + public: + // 3.3 + std::string topic; + std::vector message; + bool dup = false; + QOSLevel qos = QOSLevel::QOS0; + bool retain = false; + optional packet_identifier; + + ErrorCode encode(std::vector &target) const final { + uint8_t flags = 0; + if (dup) + flags |= 0x08; + flags |= static_cast(qos) << 1; + if (retain) + flags |= 0x01; + std::vector variable_header; + ErrorCode ec = util::encode_utf8(variable_header, topic); + if (ec != ErrorCode::OK) + return ec; + if (packet_identifier.has_value()) { + ec = util::encode_uint16(variable_header, *packet_identifier); + if (ec != ErrorCode::OK) + return ec; + } + + return util::encode_packet( + target, PacketType::PUBLISH, flags, + variable_header, message + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + dup = flags & 0x08; + qos = static_cast((flags >> 1) & 3); + retain = flags & 0x01; + ErrorCode ec = util::decode_utf8(&parser, &topic); + if (ec != ErrorCode::OK) + return ec; + if (qos == QOSLevel::QOS1 || qos == QOSLevel::QOS2) { + packet_identifier = 0; + ec = util::decode_uint16(&parser, &(*packet_identifier)); + if (ec != ErrorCode::OK) + return ec; + } else { + packet_identifier.reset(); + } + message.clear(); + message.reserve(parser.size_left()); + while (parser.size_left()) + message.push_back(parser.consume()); + return ErrorCode::OK; + } +}; + +class PubackPacket : public MQTTPacket { + public: + // 3.4 + uint16_t packet_identifier; + + ErrorCode encode(std::vector &target) const final { + std::vector variable_header; + ErrorCode ec = util::encode_uint16(variable_header, packet_identifier); + if (ec != ErrorCode::OK) + return ec; + return util::encode_packet( + target, PacketType::PUBACK, 0, + variable_header + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 0) + return ErrorCode::BAD_FLAGS; + if (parser.size_left() != 2) + return ErrorCode::MALFORMED_PACKET; + return util::decode_uint16(&parser, &packet_identifier); + } +}; + +class PubrecPacket : public MQTTPacket { + public: + // 3.5 + uint16_t packet_identifier; + + ErrorCode encode(std::vector &target) const final { + std::vector variable_header; + ErrorCode ec = util::encode_uint16(variable_header, packet_identifier); + if (ec != ErrorCode::OK) + return ec; + return util::encode_packet( + target, PacketType::PUBREC, 0, + variable_header + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 0) + return ErrorCode::BAD_FLAGS; + if (parser.size_left() != 2) + return ErrorCode::MALFORMED_PACKET; + return util::decode_uint16(&parser, &packet_identifier); + } +}; + +class PubrelPacket : public MQTTPacket { + public: + // 3.6 + uint16_t packet_identifier; + + ErrorCode encode(std::vector &target) const final { + std::vector variable_header; + ErrorCode ec = util::encode_uint16(variable_header, packet_identifier); + if (ec != ErrorCode::OK) + return ec; + return util::encode_packet( + target, PacketType::PUBREL, 2, + variable_header + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 2) + return ErrorCode::BAD_FLAGS; + if (parser.size_left() != 2) + return ErrorCode::MALFORMED_PACKET; + return util::decode_uint16(&parser, &packet_identifier); + } +}; + +class PubcompPacket : public MQTTPacket { + public: + // 3.7 + uint16_t packet_identifier; + + ErrorCode encode(std::vector &target) const final { + std::vector variable_header; + ErrorCode ec = util::encode_uint16(variable_header, packet_identifier); + if (ec != ErrorCode::OK) + return ec; + return util::encode_packet( + target, PacketType::PUBCOMP, 2, + variable_header + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 2) + return ErrorCode::BAD_FLAGS; + if (parser.size_left() != 2) + return ErrorCode::MALFORMED_PACKET; + return util::decode_uint16(&parser, &packet_identifier); + } +}; + +struct Subscription { + std::string topic_filter; + QOSLevel requested_qos = QOSLevel::QOS0; +}; + +class SubscribePacket : public MQTTPacket { + public: + // 3.8 + uint16_t packet_identifier; + std::vector subscriptions; + + ErrorCode encode(std::vector &target) const final { + std::vector variable_header; + ErrorCode ec = util::encode_uint16(variable_header, packet_identifier); + if (ec != ErrorCode::OK) + return ec; + std::vector payload; + for (const auto &sub : subscriptions) { + ec = util::encode_utf8(payload, sub.topic_filter); + if (ec != ErrorCode::OK) + return ec; + payload.push_back(static_cast(sub.requested_qos)); + } + return util::encode_packet( + target, PacketType::SUBSCRIBE, 2, + variable_header, payload + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 2) + return ErrorCode::BAD_FLAGS; + + ErrorCode ec = util::decode_uint16(&parser, &packet_identifier); + if (ec != ErrorCode::OK) + return ec; + subscriptions.clear(); + while (parser.size_left()) { + Subscription sub{}; + ec = util::decode_utf8(&parser, &sub.topic_filter); + if (ec != ErrorCode::OK) + return ec; + if (parser.size_left() < 1) + return ErrorCode::MALFORMED_PACKET; + sub.requested_qos = static_cast(parser.consume()); + subscriptions.push_back(sub); + } + return ErrorCode::OK; + } +}; + +enum class SubackReturnCode : uint8_t { + SUCCESS_MAX_QOS0 = 0x00, + SUCCESS_MAX_QOS1 = 0x01, + SUCCESS_MAX_QOS2 = 0x02, + FAILURE = 0x80, +}; + +class SubackPacket : public MQTTPacket { + public: + // 3.9 + uint16_t packet_identifier; + std::vector return_codes; + + ErrorCode encode(std::vector &target) const final { + std::vector variable_header; + ErrorCode ec = util::encode_uint16(variable_header, packet_identifier); + std::vector payload; + payload.reserve(return_codes.size()); + for (SubackReturnCode rc : return_codes) { + payload.push_back(static_cast(rc)); + } + return util::encode_packet( + target, PacketType::SUBACK, 0, + variable_header, payload + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 2) + return ErrorCode::BAD_FLAGS; + ErrorCode ec = util::decode_uint16(&parser, &packet_identifier); + if (ec != ErrorCode::OK) + return ec; + return_codes.clear(); + return_codes.reserve(parser.size_left()); + while (parser.size_left()) { + return_codes.push_back(static_cast(parser.consume())); + } + return ErrorCode::OK; + } +}; + +class UnsubscribePacket : public MQTTPacket { + public: + // 3.10 + uint16_t packet_identifier; + std::vector topic_filters; + + ErrorCode encode(std::vector &target) const final { + std::vector variable_header; + ErrorCode ec = util::encode_uint16(variable_header, packet_identifier); + if (ec != ErrorCode::OK) + return ec; + std::vector payload; + for (const auto &topic : topic_filters) { + ec = util::encode_utf8(payload, topic); + if (ec != ErrorCode::OK) + return ec; + } + return util::encode_packet( + target, PacketType::UNSUBACK, 2, + variable_header, payload + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 2) + return ErrorCode::BAD_FLAGS; + + ErrorCode ec = util::decode_uint16(&parser, &packet_identifier); + if (ec != ErrorCode::OK) + return ec; + topic_filters.clear(); + while (parser.size_left()) { + std::string topic; + ec = util::decode_utf8(&parser, &topic); + if (ec != ErrorCode::OK) + return ec; + topic_filters.push_back(topic); + } + return ErrorCode::OK; + } +}; + +class UnsubackPacket : public MQTTPacket { + public: + // 3.11 + uint16_t packet_identifier; + + ErrorCode encode(std::vector &target) const final { + std::vector variable_header; + ErrorCode ec = util::encode_uint16(variable_header, packet_identifier); + if (ec != ErrorCode::OK) + return ec; + return util::encode_packet( + target, PacketType::UNSUBACK, 0, + variable_header + ); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 0) + return ErrorCode::BAD_FLAGS; + if (parser.size_left() != 2) + return ErrorCode::MALFORMED_PACKET; + return util::decode_uint16(&parser, &packet_identifier); + } +}; + +class PingreqPacket : public MQTTPacket { + public: + // 3.12 + + ErrorCode encode(std::vector &target) const final { + return util::encode_packet(target, PacketType::PINGREQ, 0); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 0) + return ErrorCode::BAD_FLAGS; + if (parser.size_left() != 0) + return ErrorCode::MALFORMED_PACKET; + return ErrorCode::OK; + } +}; + +class PingrespPacket : public MQTTPacket { + public: + // 3.13 + + ErrorCode encode(std::vector &target) const final { + return util::encode_packet(target, PacketType::PINGRESP, 0); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 0) + return ErrorCode::BAD_FLAGS; + if (parser.size_left() != 0) + return ErrorCode::MALFORMED_PACKET; + return ErrorCode::OK; + } +}; + +class DisconnectPacket : public MQTTPacket { + public: + // 3.14 + + ErrorCode encode(std::vector &target) const final { + return util::encode_packet(target, PacketType::DISCONNECT, 0); + } + + ErrorCode decode(uint8_t flags, util::Parser parser) override final { + if (flags != 0) + return ErrorCode::BAD_FLAGS; + if (parser.size_left() != 0) + return ErrorCode::MALFORMED_PACKET; + return ErrorCode::OK; + } +}; + +} // namespace mqtt +} // namespace esphome + +#endif // USE_MQTT