Initial native MQTT commit

This commit is contained in:
Otto winter 2022-02-03 14:14:01 +01:00
parent a81fc6e85d
commit 3647e32692
No known key found for this signature in database
GPG key ID: 5B48AF485DF70D0E
6 changed files with 1572 additions and 214 deletions

View file

@ -0,0 +1,561 @@
#include "mqtt_backend.h"
#ifdef USE_MQTT
#include "esphome/core/log.h"
#include "esphome/core/hal.h"
#include "esphome/core/helpers.h"
namespace esphome {
namespace mqtt {
static const char *TAG = "mqtt.backend";
ErrorCode util::ConnectionEstablisher::init(const std::string &host, uint16_t port, uint32_t timeout) {
if (state_ != State::UNINIT) {
enter_error_();
ESP_LOGD(TAG, "conn init bad state");
return ErrorCode::BAD_STATE;
}
struct addrinfo hints;
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
std::string port_s = to_string(port);
getaddrinfo_ = socket::getaddrinfo_async(host.c_str(), port_s.c_str(), &hints);
if (!getaddrinfo_) {
enter_error_();
return ErrorCode::RESOLVE_ERROR;
}
start_ = millis();
timeout_ = timeout;
state_ = State::RESOLVING;
return ErrorCode::OK;
}
void util::ConnectionEstablisher::enter_error_() {
state_ = State::ERROR;
getaddrinfo_.reset();
socket_.reset();
}
ErrorCode util::ConnectionEstablisher::loop() {
ErrorCode ec;
uint32_t now = millis();
switch (state_) {
case State::UNINIT: {
enter_error_();
state_ = State::ERROR;
ESP_LOGD(TAG, "conn uninit bad state");
return ErrorCode::BAD_STATE;
}
case State::RESOLVING: {
if (getaddrinfo_->completed()) {
struct addrinfo *res;
int r = getaddrinfo_->fetch_result(&res);
if (r != 0) {
enter_error_();
ESP_LOGW(TAG, "Address resolve failed with error %s", gai_strerror(r));
return ErrorCode::RESOLVE_ERROR;
}
if (res == nullptr) {
enter_error_();
ESP_LOGW(TAG, "Address resolve returned no results");
return ErrorCode::RESOLVE_ERROR;
}
ESP_LOGD(TAG, "address resolved!");
socket_ = socket::socket(res->ai_family, res->ai_socktype, res->ai_protocol);
if (!socket_) {
freeaddrinfo(res);
enter_error_();
ESP_LOGW(TAG, "Socket creation failed with error %s", strerror(errno));
return ErrorCode::SOCKET_ERROR;
}
r = socket_->setblocking(false);
if (r != 0) {
enter_error_();
ESP_LOGV(TAG, "Setting nonblocking socket failed with error %s", strerror(errno));
return ErrorCode::SOCKET_ERROR;
}
r = socket_->connect(res->ai_addr, res->ai_addrlen);
freeaddrinfo(res);
if (r == 0) {
// connection established immediately
getaddrinfo_.reset();
state_ = State::CONNECTED;
return ErrorCode::OK;
} else if (errno == EINPROGRESS) {
getaddrinfo_.reset();
state_ = State::CONNECTING;
} else {
enter_error_();
ESP_LOGW(TAG, "Socket connect failed with error %s", strerror(errno));
return ErrorCode::SOCKET_ERROR;
}
}
if (now - start_ >= timeout_) {
enter_error_();
ESP_LOGW(TAG, "Timeout resolving address");
return ErrorCode::TIMEOUT;
}
return ErrorCode::IN_PROGRESS;
}
case State::CONNECTING: {
int r = socket_->connect_finished();
if (r == 0) {
// connection established
state_ = State::CONNECTED;
return ErrorCode::OK;
} else if (errno == EINPROGRESS) {
// not established yet
if (now - start_ >= timeout_) {
enter_error_();
ESP_LOGW(TAG, "Timeout connecting to address");
return ErrorCode::TIMEOUT;
}
return ErrorCode::IN_PROGRESS;
} else {
enter_error_();
ESP_LOGW(TAG, "Socket connect failed with error %s", strerror(errno));
return ErrorCode::SOCKET_ERROR;
}
}
case State::CONNECTED: {
return ErrorCode::OK;
}
case State::FINISHED:
case State::ERROR:
default: {
return ErrorCode::BAD_STATE;
}
}
}
std::unique_ptr<socket::Socket> util::ConnectionEstablisher::extract_socket() {
if (state_ != State::CONNECTED)
return nullptr;
state_ = State::FINISHED;
return std::move(socket_);
}
ErrorCode util::BufferedWriter::write(const std::unique_ptr<socket::Socket> &sock, const uint8_t *data, size_t len, bool do_buffer) {
if (len == 0)
return ErrorCode::OK;
ErrorCode ec;
if (!tx_buf_.empty()) {
// try to empty tx_buf_ first
ec = try_drain(sock);
if (ec != ErrorCode::OK && ec != ErrorCode::WOULD_BLOCK)
return ec;
}
if (!tx_buf_.empty()) {
// tx buf not empty, can't write now because then stream would be inconsistent
if (!do_buffer)
return ErrorCode::WOULD_BLOCK;
tx_buf_.insert(tx_buf_.end(), data, data + len);
return ErrorCode::OK;
}
ssize_t sent = sock->write(data, len);
if (sent == 0 || (sent == -1 && (errno == EWOULDBLOCK || errno == EAGAIN))) {
// operation would block, add to tx_buf if buffering
if (!do_buffer)
return ErrorCode::WOULD_BLOCK;
tx_buf_.insert(tx_buf_.end(), data, data + len);
return ErrorCode::OK;
} else if (sent == -1) {
// an error occured
ESP_LOGV(TAG, "Socket write failed with errno %d", errno);
return ErrorCode::SOCKET_ERROR;
} else if ((size_t) sent != len) {
// partially sent, add end to tx_buf (even if not set to buffering, to prevent
// partial packet transmission)
tx_buf_.insert(tx_buf_.end(), data + sent, data + len);
return ErrorCode::OK;
}
// fully sent
return ErrorCode::OK;
}
ErrorCode util::BufferedWriter::try_drain(const std::unique_ptr<socket::Socket> &sock) {
// try send from tx_buf
while (!tx_buf_.empty()) {
ssize_t sent = sock->write(tx_buf_.data(), tx_buf_.size());
if (sent == 0 || (sent == -1 && (errno = EWOULDBLOCK || errno == EAGAIN))) {
break;
} else if (sent == -1) {
ESP_LOGV(TAG, "Socket write failed with errno %d", errno);
return ErrorCode::SOCKET_ERROR;
}
// TODO: inefficient if multiple packets in txbuf
// replace with deque of buffers
tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent);
}
return ErrorCode::OK;
}
ErrorCode MQTTConnection::init(ConnectParams *params, MQTTSession *session) {
if (state_ != State::UNINIT) {
enter_error_();
return ErrorCode::BAD_STATE;
}
params_ = params;
session_ = session;
connection_establisher_ = make_unique<util::ConnectionEstablisher>();
ErrorCode ec = connection_establisher_->init(params_->host, params_->port, 5000);
if (ec != ErrorCode::OK) {
enter_error_();
return ec;
}
state_ = State::CONNECTING;
return ErrorCode::OK;
}
ErrorCode MQTTConnection::loop() {
ErrorCode ec;
switch (state_) {
case State::CONNECTING: {
ec = connection_establisher_->loop();
if (ec == ErrorCode::OK) {
// connection established
socket_ = connection_establisher_->extract_socket();
ConnectPacket packet{};
packet.client_id = params_->client_id;
packet.username = params_->username;
packet.password = params_->password;
packet.will_topic = params_->will_topic;
packet.will_message = params_->will_message;
packet.will_qos = params_->will_qos;
packet.will_retain = params_->will_retain;
packet.clean_session = true;
packet.keep_alive = params_->keep_alive;
std::vector<uint8_t> packet_enc;
ec = packet.encode(packet_enc);
if (ec != ErrorCode::OK) {
enter_error_();
return ec;
}
ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), true);
if (ec != ErrorCode::OK) {
enter_error_();
return ec;
}
state_ = State::WAIT_CONNACK;
connection_establisher_.reset();
} else if (ec != ErrorCode::IN_PROGRESS) {
enter_error_();
return ec;
}
return ErrorCode::OK;
}
case State::WAIT_CONNACK:
case State::CONNECTED: {
ec = writer_.try_drain(socket_);
if (ec != ErrorCode::OK) {
enter_error_();
return ec;
}
ec = read_packet_();
if (ec != ErrorCode::OK && ec != ErrorCode::WOULD_BLOCK) {
enter_error_();
return ec;
}
return ErrorCode::OK;
}
case State::UNINIT:
case State::DISCONNECTED:
case State::ERROR:
default: {
enter_error_();
ESP_LOGD(TAG, "bad state %d", (int) state_);
return ErrorCode::BAD_STATE;
}
}
}
void MQTTConnection::enter_error_() {
ESP_LOGD(TAG, "enter_error");
connection_establisher_.reset();
socket_.reset();
rx_header_buf_ = {};
rx_buf_ = {};
writer_.stop();
state_ = State::ERROR;
}
ErrorCode MQTTConnection::read_packet_() {
if (state_ != State::CONNECTED && state_ != State::WAIT_CONNACK) {
enter_error_();
ESP_LOGD(TAG, "read_packet_ bad state");
return ErrorCode::BAD_STATE;
}
ErrorCode ec;
if (!rx_header_parsed_) {
while (true) {
uint8_t v;
ssize_t received = socket_->read(&v, 1);
if (received == -1 && (errno == EWOULDBLOCK || errno == EAGAIN)) {
// would block
return ErrorCode::WOULD_BLOCK;
} else if (received == -1) {
// error
enter_error_();
ESP_LOGV(TAG, "Socket read failed with errno %d", errno);
return ErrorCode::SOCKET_ERROR;
} else if (received == 0) {
// EOF
enter_error_();
ESP_LOGV(TAG, "Socket EOF");
return ErrorCode::SOCKET_ERROR;
}
rx_header_buf_.push_back(v);
// try parse buf
if (rx_header_buf_.size() == 1)
continue;
rx_header_parsed_type_ = (rx_header_buf_[0] >> 4) & 0x0F;
rx_header_parsed_flags_ = (rx_header_buf_[0] >> 0) & 0x0F;
size_t multiplier = 1, value = 0;
size_t i = 1;
uint8_t enc;
bool parsed = true;
do {
if (i >= rx_header_buf_.size()) {
// not enough data yet
parsed = false;
break;
}
enc = rx_header_buf_[i];
value += (enc & 0x7F) * multiplier;
multiplier <<= 7;
} while (enc & 0x80);
if (!parsed)
continue;
rx_header_parsed_ = true;
rx_header_parsed_len_ = value;
}
}
// header reading done
// reserve space for body
if (rx_buf_.size() != rx_header_parsed_len_) {
rx_buf_.resize(rx_header_parsed_len_);
}
if (rx_buf_len_ < rx_header_parsed_len_) {
// more data to read
size_t to_read = rx_header_parsed_len_ - rx_buf_len_;
ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read);
if (received == -1) {
if (errno == EWOULDBLOCK || errno == EAGAIN) {
return ErrorCode::WOULD_BLOCK;
}
enter_error_();
ESP_LOGV(TAG, "Socket read failed with errno %d", errno);
return ErrorCode::SOCKET_ERROR;
} else if (received == 0) {
enter_error_();
ESP_LOGD(TAG, "Connection closed");
return ErrorCode::CONNECTION_CLOSED;
}
rx_buf_len_ += received;
if ((size_t) received != to_read) {
// not all read
return ErrorCode::WOULD_BLOCK;
}
}
// body reading done
ec = handle_packet_(rx_header_parsed_type_, rx_header_parsed_flags_, rx_buf_.data(), rx_buf_.size());
// prepare for next packet
rx_header_parsed_ = false;
return ec;
}
ErrorCode MQTTConnection::handle_packet_(uint8_t packet_type, uint8_t flags, const uint8_t *data, size_t len) {
util::Parser parser(data, len);
ErrorCode ec;
switch (static_cast<PacketType>(packet_type)) {
case PacketType::CONNACK: {
if (state_ != State::WAIT_CONNACK) {
enter_error_();
ESP_LOGV(TAG, "Bad state for connack %d", (int) state_);
return ErrorCode::BAD_STATE;
}
ConnackPacket packet{};
ec = packet.decode(flags, parser);
if (ec != ErrorCode::OK) {
enter_error_();
ESP_LOGV(TAG, "Error decoding connack packet %d", (int) ec);
return ec;
}
if (packet.connect_return_code != ConnectReturnCode::ACCEPTED) {
const char *reason;
switch (packet.connect_return_code) {
case ConnectReturnCode::UNACCEPTABLE_PROTOCOL_VERSION:
reason = "unacceptable protocol version";
break;
case ConnectReturnCode::IDENTIFIER_REJECTED:
reason = "identifier rejected";
break;
case ConnectReturnCode::SERVER_UNAVAILABLE:
reason = "server unavailable";
break;
case ConnectReturnCode::BAD_USER_NAME_OR_PASSWORD:
reason = "bad user name or password";
break;
case ConnectReturnCode::NOT_AUTHORIZED:
reason = "not authorized";
break;
default:
reason = "unknown";
break;
}
enter_error_();
ESP_LOGW(TAG, "Connect failed: %s", reason);
return ErrorCode::PROTOCOL_ERROR;
}
ESP_LOGD(TAG, "Connected!");
state_ = State::CONNECTED;
return ErrorCode::OK;
}
case PacketType::PUBLISH: {
if (state_ != State::CONNECTED) {
enter_error_();
ESP_LOGV(TAG, "Bad state for publish %d", (int) state_);
return ErrorCode::BAD_STATE;
}
return ErrorCode::OK;
}
case PacketType::PUBACK:
case PacketType::PUBREC:
case PacketType::PUBREL:
case PacketType::PUBCOMP:
case PacketType::SUBACK:
case PacketType::UNSUBACK: {
// TODO
ESP_LOGD(TAG, "Received packet with type %u", packet_type);
return ErrorCode::OK;
}
case PacketType::PINGRESP: {
// TODO rx timer
return ErrorCode::OK;
}
case PacketType::CONNECT:
case PacketType::DISCONNECT:
case PacketType::SUBSCRIBE:
case PacketType::UNSUBSCRIBE:
case PacketType::PINGREQ:
default: {
enter_error_();
ESP_LOGW(TAG, "Received unknown packet type %u", packet_type);
return ErrorCode::UNEXPECTED;
}
}
}
ErrorCode MQTTConnection::publish(std::string topic, std::vector<uint8_t> message, bool retain, QOSLevel qos) {
PublishPacket packet{};
packet.topic = std::move(topic);
packet.message = std::move(message);
packet.retain = retain;
packet.qos = qos;
if (packet.qos != QOSLevel::QOS0) {
packet.packet_identifier = session_->create_packet_id();
}
packet.dup = false;
std::vector<uint8_t> packet_enc;
ErrorCode ec = packet.encode(packet_enc);
if (ec != ErrorCode::OK)
return ec;
ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), qos != QOSLevel::QOS0);
if (ec != ErrorCode::OK && ec != ErrorCode::WOULD_BLOCK) {
enter_error_();
ESP_LOGV(TAG, "publish write failed");
return ec;
}
return ec;
}
ErrorCode MQTTConnection::subscribe(std::vector<Subscription> subscriptions) {
SubscribePacket packet{};
packet.subscriptions = std::move(subscriptions);
packet.packet_identifier = session_->create_packet_id();
std::vector<uint8_t> packet_enc;
ErrorCode ec = packet.encode(packet_enc);
if (ec != ErrorCode::OK)
return ec;
ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), true);
if (ec != ErrorCode::OK) {
enter_error_();
ESP_LOGV(TAG, "subscribe write failed");
return ec;
}
return ec;
}
ErrorCode MQTTConnection::unsubscribe(std::vector<std::string> topic_filters) {
UnsubscribePacket packet{};
packet.topic_filters = std::move(topic_filters);
packet.packet_identifier = session_->create_packet_id();
std::vector<uint8_t> packet_enc;
ErrorCode ec = packet.encode(packet_enc);
if (ec != ErrorCode::OK)
return ec;
ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), true);
if (ec != ErrorCode::OK) {
enter_error_();
ESP_LOGV(TAG, "unsubscribe write failed");
return ec;
}
return ec;
}
} // namespace mqtt
} // namespace esphome
#endif // USE_MQTT

View file

@ -0,0 +1,143 @@
#pragma once
#include "esphome/core/defines.h"
#ifdef USE_MQTT
#include "packets.h"
#include "esphome/components/socket/socket.h"
#include "esphome/components/socket/getaddrinfo.h"
#include <memory>
#include <set>
namespace esphome {
namespace mqtt {
namespace util {
class ConnectionEstablisher {
public:
ErrorCode init(const std::string &host, uint16_t port, uint32_t timeout);
ErrorCode loop();
// Should only be called when loop() returns OK, is guaranteed to succeed
std::unique_ptr<socket::Socket> extract_socket();
protected:
void enter_error_();
std::unique_ptr<socket::Socket> socket_;
std::unique_ptr<socket::GetaddrinfoFuture> getaddrinfo_;
uint32_t timeout_;
uint32_t start_;
enum class State {
UNINIT = 0,
RESOLVING = 1,
CONNECTING = 2,
CONNECTED = 3,
FINISHED = 4,
ERROR = 5,
} state_ = State::UNINIT;
};
class BufferedWriter {
public:
ErrorCode write(const std::unique_ptr<socket::Socket> &sock, const uint8_t *data, size_t len,
bool do_buffer);
ErrorCode try_drain(const std::unique_ptr<socket::Socket> &sock);
void stop() {
tx_buf_ = {};
}
protected:
std::vector<uint8_t> tx_buf_;
};
} // namespace util
struct ConnectParams {
std::string host;
uint16_t port;
std::string client_id;
optional<std::string> username;
optional<std::vector<uint8_t>> password;
std::string will_topic;
std::vector<uint8_t> will_message;
QOSLevel will_qos;
bool will_retain;
uint16_t keep_alive;
};
class MQTTSession {
public:
bool get_has_session() const { return has_session_; }
void set_has_session(bool has_session) { has_session_ = has_session; }
void clean_session() {
packet_id_counter_ = 0;
used_packet_ids_.clear();
}
uint16_t create_packet_id() {
while (true) {
packet_id_counter_++;
if (packet_id_counter_ == 0 || used_packet_ids_.count(packet_id_counter_) > 0) {
continue;
}
used_packet_ids_.insert(packet_id_counter_);
return packet_id_counter_;
}
}
void return_packet_id(uint16_t packet_id) {
used_packet_ids_.erase(packet_id);
}
protected:
bool has_session_ = false;
uint32_t packet_id_counter_ = 0;
std::set<uint16_t> used_packet_ids_;
};
class MQTTConnection {
public:
ErrorCode init(ConnectParams *params, MQTTSession *session);
ErrorCode loop();
bool is_connected() { return state_ == State::CONNECTED; }
ErrorCode publish(std::string topic, std::vector<uint8_t> message, bool retain, QOSLevel qos);
ErrorCode subscribe(std::vector<Subscription> subscriptions);
ErrorCode unsubscribe(std::vector<std::string> topic_filters);
protected:
void enter_error_();
ErrorCode read_packet_();
ErrorCode handle_packet_(uint8_t packet_type, uint8_t flags, const uint8_t *data, size_t len);
ConnectParams *params_;
MQTTSession *session_;
std::unique_ptr<util::ConnectionEstablisher> connection_establisher_;
std::unique_ptr<socket::Socket> socket_;
std::vector<uint8_t> rx_header_buf_;
bool rx_header_parsed_ = false;
uint8_t rx_header_parsed_type_ = 0;
uint8_t rx_header_parsed_flags_ = 0;
size_t rx_header_parsed_len_ = 0;
std::vector<uint8_t> rx_buf_;
size_t rx_buf_len_ = 0;
util::BufferedWriter writer_;
enum class State {
UNINIT = 0,
CONNECTING = 1,
WAIT_CONNACK = 2,
CONNECTED = 3,
DISCONNECTED = 4,
ERROR = 5,
} state_ = State::UNINIT;
};
} // namespace mqtt
} // namespace esphome
#endif // USE_MQTT

View file

@ -28,7 +28,7 @@ MQTTClientComponent::MQTTClientComponent() {
// Connection // Connection
void MQTTClientComponent::setup() { void MQTTClientComponent::setup() {
ESP_LOGCONFIG(TAG, "Setting up MQTT..."); ESP_LOGCONFIG(TAG, "Setting up MQTT...");
this->mqtt_client_.onMessage([this](char const *topic, char *payload, AsyncMqttClientMessageProperties properties, /*this->mqtt_client_.onMessage([this](char const *topic, char *payload, AsyncMqttClientMessageProperties properties,
size_t len, size_t index, size_t total) { size_t len, size_t index, size_t total) {
if (index == 0) if (index == 0)
this->payload_buffer_.reserve(total); this->payload_buffer_.reserve(total);
@ -45,7 +45,7 @@ void MQTTClientComponent::setup() {
this->mqtt_client_.onDisconnect([this](AsyncMqttClientDisconnectReason reason) { this->mqtt_client_.onDisconnect([this](AsyncMqttClientDisconnectReason reason) {
this->state_ = MQTT_CLIENT_DISCONNECTED; this->state_ = MQTT_CLIENT_DISCONNECTED;
this->disconnect_reason_ = reason; this->disconnect_reason_ = reason;
}); });*/
#ifdef USE_LOGGER #ifdef USE_LOGGER
if (this->is_log_message_enabled() && logger::global_logger != nullptr) { if (this->is_log_message_enabled() && logger::global_logger != nullptr) {
logger::global_logger->add_on_log_callback([this](int level, const char *tag, const char *message) { logger::global_logger->add_on_log_callback([this](int level, const char *tag, const char *message) {
@ -58,12 +58,11 @@ void MQTTClientComponent::setup() {
#endif #endif
this->last_connected_ = millis(); this->last_connected_ = millis();
this->start_dnslookup_(); this->start_connect_();
} }
void MQTTClientComponent::dump_config() { void MQTTClientComponent::dump_config() {
ESP_LOGCONFIG(TAG, "MQTT:"); ESP_LOGCONFIG(TAG, "MQTT:");
ESP_LOGCONFIG(TAG, " Server Address: %s:%u (%s)", this->credentials_.address.c_str(), this->credentials_.port, ESP_LOGCONFIG(TAG, " Server Address: %s:%u", this->credentials_.address.c_str(), this->credentials_.port);
this->ip_.str().c_str());
ESP_LOGCONFIG(TAG, " Username: " LOG_SECRET("'%s'"), this->credentials_.username.c_str()); ESP_LOGCONFIG(TAG, " Username: " LOG_SECRET("'%s'"), this->credentials_.username.c_str());
ESP_LOGCONFIG(TAG, " Client ID: " LOG_SECRET("'%s'"), this->credentials_.client_id.c_str()); ESP_LOGCONFIG(TAG, " Client ID: " LOG_SECRET("'%s'"), this->credentials_.client_id.c_str());
if (!this->discovery_info_.prefix.empty()) { if (!this->discovery_info_.prefix.empty()) {
@ -80,131 +79,68 @@ void MQTTClientComponent::dump_config() {
} }
bool MQTTClientComponent::can_proceed() { return this->is_connected(); } bool MQTTClientComponent::can_proceed() { return this->is_connected(); }
void MQTTClientComponent::start_dnslookup_() {
void MQTTClientComponent::start_connect_() {
if (!network::is_connected())
return;
for (auto &subscription : this->subscriptions_) { for (auto &subscription : this->subscriptions_) {
subscription.subscribed = false; subscription.subscribed = false;
subscription.resubscribe_timeout = 0; subscription.resubscribe_timeout = 0;
} }
this->status_set_warning(); this->status_set_warning();
this->dns_resolve_error_ = false;
this->dns_resolved_ = false;
ip_addr_t addr;
#ifdef USE_ESP32
err_t err = dns_gethostbyname_addrtype(this->credentials_.address.c_str(), &addr,
MQTTClientComponent::dns_found_callback, this, LWIP_DNS_ADDRTYPE_IPV4);
#endif
#ifdef USE_ESP8266
err_t err = dns_gethostbyname(this->credentials_.address.c_str(), &addr,
esphome::mqtt::MQTTClientComponent::dns_found_callback, this);
#endif
switch (err) {
case ERR_OK: {
// Got IP immediately
this->dns_resolved_ = true;
#ifdef USE_ESP32
this->ip_ = addr.u_addr.ip4.addr;
#endif
#ifdef USE_ESP8266
this->ip_ = addr.addr;
#endif
this->start_connect_();
return;
}
case ERR_INPROGRESS: {
// wait for callback
ESP_LOGD(TAG, "Resolving MQTT broker IP address...");
break;
}
default:
case ERR_ARG: {
// error
#if defined(USE_ESP8266)
ESP_LOGW(TAG, "Error resolving MQTT broker IP address: %ld", err);
#else
ESP_LOGW(TAG, "Error resolving MQTT broker IP address: %d", err);
#endif
break;
}
}
this->state_ = MQTT_CLIENT_RESOLVING_ADDRESS;
this->connect_begin_ = millis();
}
void MQTTClientComponent::check_dnslookup_() {
if (!this->dns_resolved_ && millis() - this->connect_begin_ > 20000) {
this->dns_resolve_error_ = true;
}
if (this->dns_resolve_error_) {
ESP_LOGW(TAG, "Couldn't resolve IP address for '%s'!", this->credentials_.address.c_str());
this->state_ = MQTT_CLIENT_DISCONNECTED;
return;
}
if (!this->dns_resolved_) {
return;
}
ESP_LOGD(TAG, "Resolved broker IP address to %s", this->ip_.str().c_str());
this->start_connect_();
}
#if defined(USE_ESP8266) && LWIP_VERSION_MAJOR == 1
void MQTTClientComponent::dns_found_callback(const char *name, ip_addr_t *ipaddr, void *callback_arg) {
#else
void MQTTClientComponent::dns_found_callback(const char *name, const ip_addr_t *ipaddr, void *callback_arg) {
#endif
auto *a_this = (MQTTClientComponent *) callback_arg;
if (ipaddr == nullptr) {
a_this->dns_resolve_error_ = true;
} else {
#ifdef USE_ESP32
a_this->ip_ = ipaddr->u_addr.ip4.addr;
#endif
#ifdef USE_ESP8266
a_this->ip_ = ipaddr->addr;
#endif
a_this->dns_resolved_ = true;
}
}
void MQTTClientComponent::start_connect_() {
if (!network::is_connected())
return;
ESP_LOGI(TAG, "Connecting to MQTT..."); ESP_LOGI(TAG, "Connecting to MQTT...");
// Force disconnect first
this->mqtt_client_.disconnect(true);
this->mqtt_client_.setClientId(this->credentials_.client_id.c_str()); conn_params_.host = credentials_.address;
const char *username = nullptr; conn_params_.port = credentials_.port;
if (!this->credentials_.username.empty()) conn_params_.client_id = credentials_.client_id;
username = this->credentials_.username.c_str();
const char *password = nullptr;
if (!this->credentials_.password.empty())
password = this->credentials_.password.c_str();
this->mqtt_client_.setCredentials(username, password); if (!credentials_.username.empty())
conn_params_.username = credentials_.username;
this->mqtt_client_.setServer((uint32_t) this->ip_, this->credentials_.port); else
if (!this->last_will_.topic.empty()) { conn_params_.username.reset();
this->mqtt_client_.setWill(this->last_will_.topic.c_str(), this->last_will_.qos, this->last_will_.retain, if (!credentials_.password.empty()) {
this->last_will_.payload.c_str(), this->last_will_.payload.length()); std::vector<uint8_t> pwd{credentials_.password.begin(), credentials_.password.end()};
conn_params_.password = pwd;
} else {
conn_params_.password.reset();
}
if (!last_will_.topic.empty()) {
conn_params_.will_topic = last_will_.topic;
std::vector<uint8_t> msg{last_will_.payload.begin(), last_will_.payload.end()};
conn_params_.will_message = msg;
conn_params_.will_retain = last_will_.retain;
conn_params_.will_qos = static_cast<QOSLevel>(last_will_.qos);
} else {
conn_params_.will_topic = "";
conn_params_.will_message.clear();
conn_params_.will_retain = false;
conn_params_.will_qos = QOSLevel::QOS0;
}
conn_ = make_unique<MQTTConnection>();
ErrorCode ec = conn_->init(&conn_params_, &sess_);
if (ec != ErrorCode::OK) {
ESP_LOGW(TAG, "connection init failed: %d", (int) ec);
return;
} }
this->mqtt_client_.connect();
this->state_ = MQTT_CLIENT_CONNECTING; this->state_ = MQTT_CLIENT_CONNECTING;
this->connect_begin_ = millis(); this->connect_begin_ = millis();
} }
bool MQTTClientComponent::is_connected() { bool MQTTClientComponent::is_connected() {
return this->state_ == MQTT_CLIENT_CONNECTED && this->mqtt_client_.connected(); return this->state_ == MQTT_CLIENT_CONNECTED && this->conn_->is_connected();
} }
void MQTTClientComponent::check_connected() { void MQTTClientComponent::check_connected() {
if (!this->mqtt_client_.connected()) { if (conn_ && !conn_->is_connected()) {
if (millis() - this->connect_begin_ > 60000) { ErrorCode ec = conn_->loop();
this->state_ = MQTT_CLIENT_DISCONNECTED; if (ec != ErrorCode::OK) {
this->start_dnslookup_(); ESP_LOGW(TAG, "check connected loop failed: %d", (int) ec);
state_ = MQTT_CLIENT_DISCONNECTED;
} }
return; return;
} }
@ -223,64 +159,25 @@ void MQTTClientComponent::check_connected() {
} }
void MQTTClientComponent::loop() { void MQTTClientComponent::loop() {
if (this->disconnect_reason_.has_value()) {
const LogString *reason_s;
switch (*this->disconnect_reason_) {
case AsyncMqttClientDisconnectReason::TCP_DISCONNECTED:
reason_s = LOG_STR("TCP disconnected");
break;
case AsyncMqttClientDisconnectReason::MQTT_UNACCEPTABLE_PROTOCOL_VERSION:
reason_s = LOG_STR("Unacceptable Protocol Version");
break;
case AsyncMqttClientDisconnectReason::MQTT_IDENTIFIER_REJECTED:
reason_s = LOG_STR("Identifier Rejected");
break;
case AsyncMqttClientDisconnectReason::MQTT_SERVER_UNAVAILABLE:
reason_s = LOG_STR("Server Unavailable");
break;
case AsyncMqttClientDisconnectReason::MQTT_MALFORMED_CREDENTIALS:
reason_s = LOG_STR("Malformed Credentials");
break;
case AsyncMqttClientDisconnectReason::MQTT_NOT_AUTHORIZED:
reason_s = LOG_STR("Not Authorized");
break;
case AsyncMqttClientDisconnectReason::ESP8266_NOT_ENOUGH_SPACE:
reason_s = LOG_STR("Not Enough Space");
break;
case AsyncMqttClientDisconnectReason::TLS_BAD_FINGERPRINT:
reason_s = LOG_STR("TLS Bad Fingerprint");
break;
default:
reason_s = LOG_STR("Unknown");
break;
}
if (!network::is_connected()) {
reason_s = LOG_STR("WiFi disconnected");
}
ESP_LOGW(TAG, "MQTT Disconnected: %s.", LOG_STR_ARG(reason_s));
this->disconnect_reason_.reset();
}
const uint32_t now = millis(); const uint32_t now = millis();
switch (this->state_) { switch (this->state_) {
case MQTT_CLIENT_DISCONNECTED: case MQTT_CLIENT_DISCONNECTED:
if (now - this->connect_begin_ > 5000) {
this->start_dnslookup_();
}
break;
case MQTT_CLIENT_RESOLVING_ADDRESS:
this->check_dnslookup_();
break; break;
case MQTT_CLIENT_CONNECTING: case MQTT_CLIENT_CONNECTING:
this->check_connected(); this->check_connected();
break; break;
case MQTT_CLIENT_CONNECTED: case MQTT_CLIENT_CONNECTED:
if (!this->mqtt_client_.connected()) { if (!this->conn_->is_connected()) {
this->state_ = MQTT_CLIENT_DISCONNECTED; this->state_ = MQTT_CLIENT_DISCONNECTED;
ESP_LOGW(TAG, "Lost MQTT Client connection!"); ESP_LOGW(TAG, "Lost MQTT Client connection!");
this->start_dnslookup_(); this->start_connect_();
} else { } else {
ErrorCode ec = conn_->loop();
if (ec != ErrorCode::OK) {
ESP_LOGW(TAG, "loop loop failed");
}
if (!this->birth_message_.topic.empty() && !this->sent_birth_message_) { if (!this->birth_message_.topic.empty() && !this->sent_birth_message_) {
this->sent_birth_message_ = this->publish(this->birth_message_); this->sent_birth_message_ = this->publish(this->birth_message_);
} }
@ -303,17 +200,18 @@ bool MQTTClientComponent::subscribe_(const char *topic, uint8_t qos) {
if (!this->is_connected()) if (!this->is_connected())
return false; return false;
uint16_t ret = this->mqtt_client_.subscribe(topic, qos); Subscription sub{};
yield(); sub.topic_filter = topic;
sub.requested_qos = static_cast<QOSLevel>(qos);
ErrorCode ec = this->conn_->subscribe({sub});
if (ret != 0) { if (ec != ErrorCode::OK) {
ESP_LOGV(TAG, "subscribe(topic='%s')", topic);
} else {
delay(5);
ESP_LOGV(TAG, "Subscribe failed for topic='%s'. Will retry later.", topic); ESP_LOGV(TAG, "Subscribe failed for topic='%s'. Will retry later.", topic);
this->status_momentary_warning("subscribe", 1000); this->status_momentary_warning("subscribe", 1000);
} }
return ret != 0;
ESP_LOGV(TAG, "subscribe(topic='%s')", topic);
return ec == ErrorCode::OK;
} }
void MQTTClientComponent::resubscribe_subscription_(MQTTSubscription *sub) { void MQTTClientComponent::resubscribe_subscription_(MQTTSubscription *sub) {
if (sub->subscribed) if (sub->subscribed)
@ -361,15 +259,12 @@ void MQTTClientComponent::subscribe_json(const std::string &topic, const mqtt_js
} }
void MQTTClientComponent::unsubscribe(const std::string &topic) { void MQTTClientComponent::unsubscribe(const std::string &topic) {
uint16_t ret = this->mqtt_client_.unsubscribe(topic.c_str()); ErrorCode ec = this->conn_->unsubscribe({topic});
yield(); if (ec != ErrorCode::OK) {
if (ret != 0) {
ESP_LOGV(TAG, "unsubscribe(topic='%s')", topic.c_str());
} else {
delay(5);
ESP_LOGV(TAG, "Unsubscribe failed for topic='%s'.", topic.c_str()); ESP_LOGV(TAG, "Unsubscribe failed for topic='%s'.", topic.c_str());
this->status_momentary_warning("unsubscribe", 1000); this->status_momentary_warning("unsubscribe", 1000);
} }
ESP_LOGV(TAG, "unsubscribe(topic='%s')", topic.c_str());
auto it = subscriptions_.begin(); auto it = subscriptions_.begin();
while (it != subscriptions_.end()) { while (it != subscriptions_.end()) {
@ -393,24 +288,22 @@ bool MQTTClientComponent::publish(const std::string &topic, const char *payload,
return false; return false;
} }
bool logging_topic = topic == this->log_message_.topic; bool logging_topic = topic == this->log_message_.topic;
uint16_t ret = this->mqtt_client_.publish(topic.c_str(), qos, retain, payload, payload_length); std::vector<uint8_t> msg;
delay(0); for (size_t i = 0; i < payload_length; i++) {
if (ret == 0 && !logging_topic && this->is_connected()) { msg.push_back(static_cast<uint8_t>(payload[i]));
delay(0);
ret = this->mqtt_client_.publish(topic.c_str(), qos, retain, payload, payload_length);
delay(0);
} }
ErrorCode ec = this->conn_->publish(topic, std::move(msg), retain, static_cast<QOSLevel>(qos));
if (!logging_topic) { if (!logging_topic) {
if (ret != 0) { if (ec != ErrorCode::OK) {
ESP_LOGV(TAG, "Publish(topic='%s' payload='%s' retain=%d)", topic.c_str(), payload, retain);
} else {
ESP_LOGV(TAG, "Publish failed for topic='%s' (len=%u). will retry later..", topic.c_str(), ESP_LOGV(TAG, "Publish failed for topic='%s' (len=%u). will retry later..", topic.c_str(),
payload_length); // NOLINT payload_length); // NOLINT
this->status_momentary_warning("publish", 1000); this->status_momentary_warning("publish", 1000);
} else {
ESP_LOGV(TAG, "Publish(topic='%s' payload='%s' retain=%d)", topic.c_str(), payload, retain);
} }
} }
return ret != 0; return ec == ErrorCode::OK;
} }
bool MQTTClientComponent::publish(const MQTTMessage &message) { bool MQTTClientComponent::publish(const MQTTMessage &message) {
@ -500,7 +393,7 @@ bool MQTTClientComponent::is_log_message_enabled() const { return !this->log_mes
void MQTTClientComponent::set_reboot_timeout(uint32_t reboot_timeout) { this->reboot_timeout_ = reboot_timeout; } void MQTTClientComponent::set_reboot_timeout(uint32_t reboot_timeout) { this->reboot_timeout_ = reboot_timeout; }
void MQTTClientComponent::register_mqtt_component(MQTTComponent *component) { this->children_.push_back(component); } void MQTTClientComponent::register_mqtt_component(MQTTComponent *component) { this->children_.push_back(component); }
void MQTTClientComponent::set_log_level(int level) { this->log_level_ = level; } void MQTTClientComponent::set_log_level(int level) { this->log_level_ = level; }
void MQTTClientComponent::set_keep_alive(uint16_t keep_alive_s) { this->mqtt_client_.setKeepAlive(keep_alive_s); } void MQTTClientComponent::set_keep_alive(uint16_t keep_alive_s) { conn_params_.keep_alive = keep_alive_s; }
void MQTTClientComponent::set_log_message_template(MQTTMessage &&message) { this->log_message_ = std::move(message); } void MQTTClientComponent::set_log_message_template(MQTTMessage &&message) { this->log_message_ = std::move(message); }
const MQTTDiscoveryInfo &MQTTClientComponent::get_discovery_info() const { return this->discovery_info_; } const MQTTDiscoveryInfo &MQTTClientComponent::get_discovery_info() const { return this->discovery_info_; }
void MQTTClientComponent::set_topic_prefix(std::string topic_prefix) { this->topic_prefix_ = std::move(topic_prefix); } void MQTTClientComponent::set_topic_prefix(std::string topic_prefix) { this->topic_prefix_ = std::move(topic_prefix); }
@ -556,7 +449,7 @@ void MQTTClientComponent::on_shutdown() {
this->publish(this->shutdown_message_); this->publish(this->shutdown_message_);
yield(); yield();
} }
this->mqtt_client_.disconnect(true); // this->mqtt_client_.disconnect(true);
} }
#if ASYNC_TCP_SSL_ENABLED #if ASYNC_TCP_SSL_ENABLED

View file

@ -9,8 +9,7 @@
#include "esphome/core/log.h" #include "esphome/core/log.h"
#include "esphome/components/json/json_util.h" #include "esphome/components/json/json_util.h"
#include "esphome/components/network/ip_address.h" #include "esphome/components/network/ip_address.h"
#include <AsyncMqttClient.h> #include "mqtt_backend.h"
#include "lwip/ip_addr.h"
namespace esphome { namespace esphome {
namespace mqtt { namespace mqtt {
@ -74,7 +73,6 @@ struct MQTTDiscoveryInfo {
enum MQTTClientState { enum MQTTClientState {
MQTT_CLIENT_DISCONNECTED = 0, MQTT_CLIENT_DISCONNECTED = 0,
MQTT_CLIENT_RESOLVING_ADDRESS,
MQTT_CLIENT_CONNECTING, MQTT_CLIENT_CONNECTING,
MQTT_CLIENT_CONNECTED, MQTT_CLIENT_CONNECTED,
}; };
@ -116,22 +114,6 @@ class MQTTClientComponent : public Component {
void disable_discovery(); void disable_discovery();
bool is_discovery_enabled() const; bool is_discovery_enabled() const;
#if ASYNC_TCP_SSL_ENABLED
/** Add a SSL fingerprint to use for TCP SSL connections to the MQTT broker.
*
* To use this feature you first have to globally enable the `ASYNC_TCP_SSL_ENABLED` define flag.
* This function can be called multiple times and any certificate that matches any of the provided fingerprints
* will match. Calling this method will also automatically disable all non-ssl connections.
*
* @warning This is *not* secure and *not* how SSL is usually done. You'll have to add
* a separate fingerprint for every certificate you use. Additionally, the hashing
* algorithm used here due to the constraints of the MCU, SHA1, is known to be insecure.
*
* @param fingerprint The SSL fingerprint as a 20 value long std::array.
*/
void add_ssl_fingerprint(const std::array<uint8_t, SHA1_SIZE> &fingerprint);
#endif
const Availability &get_availability(); const Availability &get_availability();
/** Set the topic prefix that will be prepended to all topics together with "/". This will, in most cases, /** Set the topic prefix that will be prepended to all topics together with "/". This will, in most cases,
@ -237,13 +219,6 @@ class MQTTClientComponent : public Component {
protected: protected:
/// Reconnect to the MQTT broker if not already connected. /// Reconnect to the MQTT broker if not already connected.
void start_connect_(); void start_connect_();
void start_dnslookup_();
void check_dnslookup_();
#if defined(USE_ESP8266) && LWIP_VERSION_MAJOR == 1
static void dns_found_callback(const char *name, ip_addr_t *ipaddr, void *callback_arg);
#else
static void dns_found_callback(const char *name, const ip_addr_t *ipaddr, void *callback_arg);
#endif
/// Re-calculate the availability property. /// Re-calculate the availability property.
void recalculate_availability_(); void recalculate_availability_();
@ -272,20 +247,18 @@ class MQTTClientComponent : public Component {
}; };
std::string topic_prefix_{}; std::string topic_prefix_{};
MQTTMessage log_message_; MQTTMessage log_message_;
std::string payload_buffer_;
int log_level_{ESPHOME_LOG_LEVEL}; int log_level_{ESPHOME_LOG_LEVEL};
std::vector<MQTTSubscription> subscriptions_; std::vector<MQTTSubscription> subscriptions_;
AsyncMqttClient mqtt_client_;
MQTTClientState state_{MQTT_CLIENT_DISCONNECTED}; MQTTClientState state_{MQTT_CLIENT_DISCONNECTED};
network::IPAddress ip_;
bool dns_resolved_{false};
bool dns_resolve_error_{false};
std::vector<MQTTComponent *> children_; std::vector<MQTTComponent *> children_;
uint32_t reboot_timeout_{300000}; uint32_t reboot_timeout_{300000};
uint32_t connect_begin_; uint32_t connect_begin_;
uint32_t last_connected_{0}; uint32_t last_connected_{0};
optional<AsyncMqttClientDisconnectReason> disconnect_reason_{};
std::unique_ptr<MQTTConnection> conn_;
ConnectParams conn_params_{};
MQTTSession sess_{};
}; };
extern MQTTClientComponent *global_mqtt_client; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) extern MQTTClientComponent *global_mqtt_client; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

View file

@ -0,0 +1,135 @@
#include "packets.h"
#ifdef USE_MQTT
namespace esphome {
namespace mqtt {
namespace util {
ErrorCode encode_uint16(std::vector<uint8_t> &target, uint16_t value) {
target.push_back((value >> 8) & 0xFF);
target.push_back((value >> 0) & 0xFF);
return ErrorCode::OK;
}
ErrorCode decode_uint16(Parser *parser, uint16_t *value) {
if (parser->size_left() < 2)
return ErrorCode::MALFORMED_PACKET;
*value = 0;
*value |= static_cast<uint16_t>(parser->consume()) << 8;
*value |= static_cast<uint16_t>(parser->consume()) << 0;
return ErrorCode::OK;
}
ErrorCode encode_bytes(std::vector<uint8_t> &target, const std::vector<uint8_t> &value) {
if (value.size() > 65535)
return ErrorCode::VALUE_TOO_LONG;
encode_uint16(target, value.size());
target.insert(target.end(), value.begin(), value.end());
return ErrorCode::OK;
}
ErrorCode decode_bytes(Parser *parser, std::vector<uint8_t> *value) {
uint16_t len;
ErrorCode ec = decode_uint16(parser, &len);
if (ec != ErrorCode::OK)
return ec;
if (len > parser->size_left())
return ErrorCode::MALFORMED_PACKET;
value->clear();
value->reserve(len);
for (size_t i = 0; i < len; i++)
value->push_back(parser->consume());
return ErrorCode::OK;
}
ErrorCode encode_utf8(std::vector<uint8_t> &target, const std::string &value) {
if (value.size() > 65535)
return ErrorCode::VALUE_TOO_LONG;
encode_uint16(target, value.size());
for (char c : value)
target.push_back(static_cast<uint8_t>(c));
return ErrorCode::OK;
}
ErrorCode decode_utf8(Parser *parser, std::string *value) {
uint16_t len;
ErrorCode ec = decode_uint16(parser, &len);
if (ec != ErrorCode::OK)
return ec;
if (len > parser->size_left())
return ErrorCode::MALFORMED_PACKET;
value->clear();
value->reserve(len);
for (size_t i = 0; i < len; i++)
value->push_back(static_cast<char>(parser->consume()));
return ErrorCode::OK;
}
ErrorCode encode_varint(std::vector<uint8_t> &target, size_t value) {
do {
uint8_t encbyte = value % 0x80;
value >>= 7;
if (value > 0)
encbyte |= 0x80;
target.push_back(encbyte);
} while (value > 0);
return ErrorCode::OK;
}
ErrorCode encode_fixed_header(
std::vector<uint8_t> &target,
PacketType packet_type,
uint8_t flags,
size_t remaining_length
) {
uint8_t head = 0;
head |= static_cast<uint8_t>(packet_type) << 4;
head |= flags;
target.push_back(head);
return encode_varint(target, remaining_length);
}
ErrorCode encode_packet(
std::vector<uint8_t> &target,
PacketType packet_type,
uint8_t flags
) {
return encode_fixed_header(target, packet_type, flags, 0);
}
ErrorCode encode_packet(
std::vector<uint8_t> &target,
PacketType packet_type,
uint8_t flags,
const std::vector<uint8_t> &variable_header
) {
ErrorCode ec = encode_fixed_header(
target, packet_type, flags,
variable_header.size()
);
if (ec != ErrorCode::OK)
return ec;
target.insert(target.end(), variable_header.begin(), variable_header.end());
return ErrorCode::OK;
}
ErrorCode encode_packet(
std::vector<uint8_t> &target,
PacketType packet_type,
uint8_t flags,
const std::vector<uint8_t> &variable_header,
const std::vector<uint8_t> &payload
) {
ErrorCode ec = encode_fixed_header(
target, packet_type, flags,
variable_header.size() + payload.size()
);
if (ec != ErrorCode::OK)
return ec;
target.insert(target.end(), variable_header.begin(), variable_header.end());
target.insert(target.end(), payload.begin(), payload.end());
return ErrorCode::OK;
}
} // namespace util
} // namespace mqtt
} // namespace esphome
#endif // USE_MQTT

View file

@ -0,0 +1,653 @@
#pragma once
#include "esphome/core/defines.h"
#ifdef USE_MQTT
#include <string>
#include <vector>
#include <cstdint>
#include "esphome/core/optional.h"
namespace esphome {
namespace mqtt {
enum class ErrorCode {
OK = 0,
VALUE_TOO_LONG = 1,
BAD_FLAGS = 2,
MALFORMED_PACKET = 3,
BAD_STATE = 4,
RESOLVE_ERROR = 5,
SOCKET_ERROR = 6,
TIMEOUT = 7,
IN_PROGRESS = 8,
WOULD_BLOCK = 9,
CONNECTION_CLOSED = 10,
UNEXPECTED = 11,
PROTOCOL_ERROR = 12,
};
enum class PacketType : uint8_t {
CONNECT = 1,
CONNACK = 2,
PUBLISH = 3,
PUBACK = 4,
PUBREC = 5,
PUBREL = 6,
PUBCOMP = 7,
SUBSCRIBE = 8,
SUBACK = 9,
UNSUBSCRIBE = 10,
UNSUBACK = 11,
PINGREQ = 12,
PINGRESP = 13,
DISCONNECT = 14,
};
namespace util {
class Parser {
public:
Parser(const uint8_t *data, size_t len) : data_(data), len_(len) {}
size_t size_left() const {
return len_ - at_;
}
uint8_t consume() {
return data_[at_++];
}
void consume(size_t amount) {
at_ += amount;
}
private:
const uint8_t *data_;
size_t len_;
size_t at_ = 0;
};
ErrorCode encode_uint16(std::vector<uint8_t> &target, uint16_t value);
ErrorCode decode_uint16(Parser *parser, uint16_t *value);
ErrorCode encode_bytes(std::vector<uint8_t> &target, const std::vector<uint8_t> &value);
ErrorCode decode_bytes(Parser *parser, std::vector<uint8_t> *value);
ErrorCode encode_utf8(std::vector<uint8_t> &target, const std::string &value);
ErrorCode decode_utf8(Parser *parser, std::string *value);
ErrorCode encode_varint(std::vector<uint8_t> &target, size_t value);
ErrorCode encode_fixed_header(
std::vector<uint8_t> &target,
PacketType packet_type,
uint8_t flags,
size_t remaining_length
);
ErrorCode encode_packet(
std::vector<uint8_t> &target,
PacketType packet_type,
uint8_t flags
);
ErrorCode encode_packet(
std::vector<uint8_t> &target,
PacketType packet_type,
uint8_t flags,
const std::vector<uint8_t> &variable_header
);
ErrorCode encode_packet(
std::vector<uint8_t> &target,
PacketType packet_type,
uint8_t flags,
const std::vector<uint8_t> &variable_header,
const std::vector<uint8_t> &payload
);
} // namespace util
class MQTTPacket {
public:
virtual ErrorCode encode(std::vector<uint8_t> &target) const = 0;
virtual ErrorCode decode(uint8_t flags, util::Parser parser) = 0;
};
enum class QOSLevel : uint8_t {
QOS0 = 0,
QOS1 = 1,
QOS2 = 2,
};
class ConnectPacket : public MQTTPacket {
public:
// 3.1
std::string client_id;
optional<std::string> username;
optional<std::vector<uint8_t>> password;
std::string will_topic;
std::vector<uint8_t> will_message;
QOSLevel will_qos = QOSLevel::QOS0;
bool will_retain = false;
bool clean_session = true;
uint8_t protocol_level = 4;
uint16_t keep_alive;
ErrorCode encode(std::vector<uint8_t> &target) const final {
uint8_t connect_flags = 0;
if (username.has_value())
connect_flags |= 0x80;
if (password.has_value())
connect_flags |= 0x40;
if (will_retain)
connect_flags |= 0x20;
connect_flags |= static_cast<uint8_t>(will_qos) << 3;
if (!will_topic.empty())
connect_flags |= 0x04;
if (clean_session)
connect_flags |= 0x02;
std::vector<uint8_t> variable_header;
variable_header.push_back(0x00);
variable_header.push_back(0x04);
variable_header.push_back('M');
variable_header.push_back('Q');
variable_header.push_back('T');
variable_header.push_back('T');
variable_header.push_back(protocol_level);
variable_header.push_back(connect_flags);
ErrorCode ec = util::encode_uint16(variable_header, keep_alive);
if (ec != ErrorCode::OK)
return ec;
std::vector<uint8_t> payload;
ec = util::encode_utf8(payload, client_id);
if (ec != ErrorCode::OK)
return ec;
if (!will_topic.empty()) {
ec = util::encode_utf8(payload, will_topic);
if (ec != ErrorCode::OK)
return ec;
ec = util::encode_bytes(payload, will_message);
if (ec != ErrorCode::OK)
return ec;
}
if (username.has_value()) {
ec = util::encode_utf8(payload, *username);
if (ec != ErrorCode::OK)
return ec;
}
if (password.has_value()) {
ec = util::encode_bytes(payload, *password);
if (ec != ErrorCode::OK)
return ec;
}
return util::encode_packet(
target, PacketType::CONNECT, 0,
variable_header, payload
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 0)
return ErrorCode::BAD_FLAGS;
if (
parser.size_left() < 10
|| parser.consume() != '\x00'
|| parser.consume() != '\x04'
|| parser.consume() != 'M'
|| parser.consume() != 'Q'
|| parser.consume() != 'T'
|| parser.consume() != 'T'
)
return ErrorCode::MALFORMED_PACKET;
protocol_level = parser.consume();
uint8_t connect_flags = parser.consume();
bool username_flag = connect_flags & 0x80;
bool password_flag = connect_flags & 0x40;
will_retain = connect_flags & 0x20;
will_qos = static_cast<QOSLevel>((connect_flags >> 3) & 3);
bool will_flag = connect_flags & 0x04;
clean_session = connect_flags & 0x02;
if ((flags & 1) != 0)
return ErrorCode::MALFORMED_PACKET;
ErrorCode ec = util::decode_uint16(&parser, &keep_alive);
if (ec != ErrorCode::OK)
return ec;
ec = util::decode_utf8(&parser, &client_id);
if (ec != ErrorCode::OK)
return ec;
will_topic.clear();
will_message.clear();
if (will_flag) {
ec = util::decode_utf8(&parser, &will_topic);
if (ec != ErrorCode::OK)
return ec;
ec = util::decode_bytes(&parser, &will_message);
if (ec != ErrorCode::OK)
return ec;
}
username.reset();
if (username_flag) {
username = {""};
ec = util::decode_utf8(&parser, &(*username));
if (ec != ErrorCode::OK)
return ec;
}
password.reset();
if (password_flag) {
password = std::vector<uint8_t>{};
ec = util::decode_bytes(&parser, &(*password));
if (ec != ErrorCode::OK)
return ec;
}
return ErrorCode::OK;
}
};
enum class ConnectReturnCode : uint8_t {
ACCEPTED = 0,
UNACCEPTABLE_PROTOCOL_VERSION = 1,
IDENTIFIER_REJECTED = 2,
SERVER_UNAVAILABLE = 3,
BAD_USER_NAME_OR_PASSWORD = 4,
NOT_AUTHORIZED = 5,
};
class ConnackPacket : public MQTTPacket {
public:
// 3.2
bool session_present;
ConnectReturnCode connect_return_code;
ErrorCode encode(std::vector<uint8_t> &target) const final {
std::vector<uint8_t> variable_header;
variable_header.push_back(static_cast<uint8_t>(session_present));
variable_header.push_back(static_cast<uint8_t>(connect_return_code));
return util::encode_packet(
target, PacketType::CONNACK, 0,
variable_header
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 0)
return ErrorCode::BAD_FLAGS;
if (parser.size_left() != 2)
return ErrorCode::MALFORMED_PACKET;
session_present = parser.consume() & 1;
connect_return_code = static_cast<ConnectReturnCode>(parser.consume());
return ErrorCode::OK;
}
};
class PublishPacket : public MQTTPacket {
public:
// 3.3
std::string topic;
std::vector<uint8_t> message;
bool dup = false;
QOSLevel qos = QOSLevel::QOS0;
bool retain = false;
optional<uint16_t> packet_identifier;
ErrorCode encode(std::vector<uint8_t> &target) const final {
uint8_t flags = 0;
if (dup)
flags |= 0x08;
flags |= static_cast<uint8_t>(qos) << 1;
if (retain)
flags |= 0x01;
std::vector<uint8_t> variable_header;
ErrorCode ec = util::encode_utf8(variable_header, topic);
if (ec != ErrorCode::OK)
return ec;
if (packet_identifier.has_value()) {
ec = util::encode_uint16(variable_header, *packet_identifier);
if (ec != ErrorCode::OK)
return ec;
}
return util::encode_packet(
target, PacketType::PUBLISH, flags,
variable_header, message
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
dup = flags & 0x08;
qos = static_cast<QOSLevel>((flags >> 1) & 3);
retain = flags & 0x01;
ErrorCode ec = util::decode_utf8(&parser, &topic);
if (ec != ErrorCode::OK)
return ec;
if (qos == QOSLevel::QOS1 || qos == QOSLevel::QOS2) {
packet_identifier = 0;
ec = util::decode_uint16(&parser, &(*packet_identifier));
if (ec != ErrorCode::OK)
return ec;
} else {
packet_identifier.reset();
}
message.clear();
message.reserve(parser.size_left());
while (parser.size_left())
message.push_back(parser.consume());
return ErrorCode::OK;
}
};
class PubackPacket : public MQTTPacket {
public:
// 3.4
uint16_t packet_identifier;
ErrorCode encode(std::vector<uint8_t> &target) const final {
std::vector<uint8_t> variable_header;
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
if (ec != ErrorCode::OK)
return ec;
return util::encode_packet(
target, PacketType::PUBACK, 0,
variable_header
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 0)
return ErrorCode::BAD_FLAGS;
if (parser.size_left() != 2)
return ErrorCode::MALFORMED_PACKET;
return util::decode_uint16(&parser, &packet_identifier);
}
};
class PubrecPacket : public MQTTPacket {
public:
// 3.5
uint16_t packet_identifier;
ErrorCode encode(std::vector<uint8_t> &target) const final {
std::vector<uint8_t> variable_header;
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
if (ec != ErrorCode::OK)
return ec;
return util::encode_packet(
target, PacketType::PUBREC, 0,
variable_header
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 0)
return ErrorCode::BAD_FLAGS;
if (parser.size_left() != 2)
return ErrorCode::MALFORMED_PACKET;
return util::decode_uint16(&parser, &packet_identifier);
}
};
class PubrelPacket : public MQTTPacket {
public:
// 3.6
uint16_t packet_identifier;
ErrorCode encode(std::vector<uint8_t> &target) const final {
std::vector<uint8_t> variable_header;
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
if (ec != ErrorCode::OK)
return ec;
return util::encode_packet(
target, PacketType::PUBREL, 2,
variable_header
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 2)
return ErrorCode::BAD_FLAGS;
if (parser.size_left() != 2)
return ErrorCode::MALFORMED_PACKET;
return util::decode_uint16(&parser, &packet_identifier);
}
};
class PubcompPacket : public MQTTPacket {
public:
// 3.7
uint16_t packet_identifier;
ErrorCode encode(std::vector<uint8_t> &target) const final {
std::vector<uint8_t> variable_header;
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
if (ec != ErrorCode::OK)
return ec;
return util::encode_packet(
target, PacketType::PUBCOMP, 2,
variable_header
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 2)
return ErrorCode::BAD_FLAGS;
if (parser.size_left() != 2)
return ErrorCode::MALFORMED_PACKET;
return util::decode_uint16(&parser, &packet_identifier);
}
};
struct Subscription {
std::string topic_filter;
QOSLevel requested_qos = QOSLevel::QOS0;
};
class SubscribePacket : public MQTTPacket {
public:
// 3.8
uint16_t packet_identifier;
std::vector<Subscription> subscriptions;
ErrorCode encode(std::vector<uint8_t> &target) const final {
std::vector<uint8_t> variable_header;
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
if (ec != ErrorCode::OK)
return ec;
std::vector<uint8_t> payload;
for (const auto &sub : subscriptions) {
ec = util::encode_utf8(payload, sub.topic_filter);
if (ec != ErrorCode::OK)
return ec;
payload.push_back(static_cast<uint8_t>(sub.requested_qos));
}
return util::encode_packet(
target, PacketType::SUBSCRIBE, 2,
variable_header, payload
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 2)
return ErrorCode::BAD_FLAGS;
ErrorCode ec = util::decode_uint16(&parser, &packet_identifier);
if (ec != ErrorCode::OK)
return ec;
subscriptions.clear();
while (parser.size_left()) {
Subscription sub{};
ec = util::decode_utf8(&parser, &sub.topic_filter);
if (ec != ErrorCode::OK)
return ec;
if (parser.size_left() < 1)
return ErrorCode::MALFORMED_PACKET;
sub.requested_qos = static_cast<QOSLevel>(parser.consume());
subscriptions.push_back(sub);
}
return ErrorCode::OK;
}
};
enum class SubackReturnCode : uint8_t {
SUCCESS_MAX_QOS0 = 0x00,
SUCCESS_MAX_QOS1 = 0x01,
SUCCESS_MAX_QOS2 = 0x02,
FAILURE = 0x80,
};
class SubackPacket : public MQTTPacket {
public:
// 3.9
uint16_t packet_identifier;
std::vector<SubackReturnCode> return_codes;
ErrorCode encode(std::vector<uint8_t> &target) const final {
std::vector<uint8_t> variable_header;
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
std::vector<uint8_t> payload;
payload.reserve(return_codes.size());
for (SubackReturnCode rc : return_codes) {
payload.push_back(static_cast<uint8_t>(rc));
}
return util::encode_packet(
target, PacketType::SUBACK, 0,
variable_header, payload
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 2)
return ErrorCode::BAD_FLAGS;
ErrorCode ec = util::decode_uint16(&parser, &packet_identifier);
if (ec != ErrorCode::OK)
return ec;
return_codes.clear();
return_codes.reserve(parser.size_left());
while (parser.size_left()) {
return_codes.push_back(static_cast<SubackReturnCode>(parser.consume()));
}
return ErrorCode::OK;
}
};
class UnsubscribePacket : public MQTTPacket {
public:
// 3.10
uint16_t packet_identifier;
std::vector<std::string> topic_filters;
ErrorCode encode(std::vector<uint8_t> &target) const final {
std::vector<uint8_t> variable_header;
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
if (ec != ErrorCode::OK)
return ec;
std::vector<uint8_t> payload;
for (const auto &topic : topic_filters) {
ec = util::encode_utf8(payload, topic);
if (ec != ErrorCode::OK)
return ec;
}
return util::encode_packet(
target, PacketType::UNSUBACK, 2,
variable_header, payload
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 2)
return ErrorCode::BAD_FLAGS;
ErrorCode ec = util::decode_uint16(&parser, &packet_identifier);
if (ec != ErrorCode::OK)
return ec;
topic_filters.clear();
while (parser.size_left()) {
std::string topic;
ec = util::decode_utf8(&parser, &topic);
if (ec != ErrorCode::OK)
return ec;
topic_filters.push_back(topic);
}
return ErrorCode::OK;
}
};
class UnsubackPacket : public MQTTPacket {
public:
// 3.11
uint16_t packet_identifier;
ErrorCode encode(std::vector<uint8_t> &target) const final {
std::vector<uint8_t> variable_header;
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
if (ec != ErrorCode::OK)
return ec;
return util::encode_packet(
target, PacketType::UNSUBACK, 0,
variable_header
);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 0)
return ErrorCode::BAD_FLAGS;
if (parser.size_left() != 2)
return ErrorCode::MALFORMED_PACKET;
return util::decode_uint16(&parser, &packet_identifier);
}
};
class PingreqPacket : public MQTTPacket {
public:
// 3.12
ErrorCode encode(std::vector<uint8_t> &target) const final {
return util::encode_packet(target, PacketType::PINGREQ, 0);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 0)
return ErrorCode::BAD_FLAGS;
if (parser.size_left() != 0)
return ErrorCode::MALFORMED_PACKET;
return ErrorCode::OK;
}
};
class PingrespPacket : public MQTTPacket {
public:
// 3.13
ErrorCode encode(std::vector<uint8_t> &target) const final {
return util::encode_packet(target, PacketType::PINGRESP, 0);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 0)
return ErrorCode::BAD_FLAGS;
if (parser.size_left() != 0)
return ErrorCode::MALFORMED_PACKET;
return ErrorCode::OK;
}
};
class DisconnectPacket : public MQTTPacket {
public:
// 3.14
ErrorCode encode(std::vector<uint8_t> &target) const final {
return util::encode_packet(target, PacketType::DISCONNECT, 0);
}
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
if (flags != 0)
return ErrorCode::BAD_FLAGS;
if (parser.size_left() != 0)
return ErrorCode::MALFORMED_PACKET;
return ErrorCode::OK;
}
};
} // namespace mqtt
} // namespace esphome
#endif // USE_MQTT