From 792108686cb6b0a90b2ee7dbf6ddb24cb5487ba7 Mon Sep 17 00:00:00 2001 From: Martin <25747549+martgras@users.noreply.github.com> Date: Mon, 4 Apr 2022 01:07:20 +0200 Subject: [PATCH] Add mqtt for idf (#2930) Co-authored-by: Flaviu Tamas Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Co-authored-by: Oxan van Leeuwen --- esphome/components/mqtt/__init__.py | 37 ++++- esphome/components/mqtt/mqtt_backend.h | 69 ++++++++ .../components/mqtt/mqtt_backend_arduino.h | 74 +++++++++ esphome/components/mqtt/mqtt_backend_idf.cpp | 149 ++++++++++++++++++ esphome/components/mqtt/mqtt_backend_idf.h | 143 +++++++++++++++++ esphome/components/mqtt/mqtt_client.cpp | 117 +++++++------- esphome/components/mqtt/mqtt_client.h | 31 ++-- esphome/core/defines.h | 7 +- tests/test5.yaml | 13 ++ 9 files changed, 567 insertions(+), 73 deletions(-) create mode 100644 esphome/components/mqtt/mqtt_backend.h create mode 100644 esphome/components/mqtt/mqtt_backend_arduino.h create mode 100644 esphome/components/mqtt/mqtt_backend_idf.cpp create mode 100644 esphome/components/mqtt/mqtt_backend_idf.h diff --git a/esphome/components/mqtt/__init__.py b/esphome/components/mqtt/__init__.py index 901b77474d..b2548d6081 100644 --- a/esphome/components/mqtt/__init__.py +++ b/esphome/components/mqtt/__init__.py @@ -9,6 +9,7 @@ from esphome.const import ( CONF_AVAILABILITY, CONF_BIRTH_MESSAGE, CONF_BROKER, + CONF_CERTIFICATE_AUTHORITY, CONF_CLIENT_ID, CONF_COMMAND_TOPIC, CONF_COMMAND_RETAIN, @@ -42,9 +43,14 @@ from esphome.const import ( CONF_WILL_MESSAGE, ) from esphome.core import coroutine_with_priority, CORE +from esphome.components.esp32 import add_idf_sdkconfig_option DEPENDENCIES = ["network"] -AUTO_LOAD = ["json", "async_tcp"] + +AUTO_LOAD = ["json"] + +CONF_IDF_SEND_ASYNC = "idf_send_async" +CONF_SKIP_CERT_CN_CHECK = "skip_cert_cn_check" def validate_message_just_topic(value): @@ -163,6 +169,15 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_USERNAME, default=""): cv.string, cv.Optional(CONF_PASSWORD, default=""): cv.string, cv.Optional(CONF_CLIENT_ID): cv.string, + cv.SplitDefault(CONF_IDF_SEND_ASYNC, esp32_idf=False): cv.All( + cv.boolean, cv.only_with_esp_idf + ), + cv.Optional(CONF_CERTIFICATE_AUTHORITY): cv.All( + cv.string, cv.only_with_esp_idf + ), + cv.SplitDefault(CONF_SKIP_CERT_CN_CHECK, esp32_idf=False): cv.All( + cv.boolean, cv.only_with_esp_idf + ), cv.Optional(CONF_DISCOVERY, default=True): cv.Any( cv.boolean, cv.one_of("CLEAN", upper=True) ), @@ -217,7 +232,6 @@ CONFIG_SCHEMA = cv.All( } ), validate_config, - cv.only_with_arduino, ) @@ -238,9 +252,11 @@ def exp_mqtt_message(config): async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) + # Add required libraries for arduino + if CORE.using_arduino: + # https://github.com/OttoWinter/async-mqtt-client/blob/master/library.json + cg.add_library("ottowinter/AsyncMqttClient-esphome", "0.8.6") - # https://github.com/OttoWinter/async-mqtt-client/blob/master/library.json - cg.add_library("ottowinter/AsyncMqttClient-esphome", "0.8.6") cg.add_define("USE_MQTT") cg.add_global(mqtt_ns.using) @@ -321,6 +337,19 @@ async def to_code(config): cg.add(var.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT])) + # esp-idf only + if CONF_CERTIFICATE_AUTHORITY in config: + cg.add(var.set_ca_certificate(config[CONF_CERTIFICATE_AUTHORITY])) + cg.add(var.set_skip_cert_cn_check(config[CONF_SKIP_CERT_CN_CHECK])) + + # prevent error -0x428e + # See https://github.com/espressif/esp-idf/issues/139 + add_idf_sdkconfig_option("CONFIG_MBEDTLS_HARDWARE_MPI", False) + + if CONF_IDF_SEND_ASYNC in config and config[CONF_IDF_SEND_ASYNC]: + cg.add_define("USE_MQTT_IDF_ENQUEUE") + # end esp-idf + for conf in config.get(CONF_ON_MESSAGE, []): trig = cg.new_Pvariable(conf[CONF_TRIGGER_ID], conf[CONF_TOPIC]) cg.add(trig.set_qos(conf[CONF_QOS])) diff --git a/esphome/components/mqtt/mqtt_backend.h b/esphome/components/mqtt/mqtt_backend.h new file mode 100644 index 0000000000..d23cda578d --- /dev/null +++ b/esphome/components/mqtt/mqtt_backend.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include "esphome/components/network/ip_address.h" +#include "esphome/core/helpers.h" + +namespace esphome { +namespace mqtt { + +enum class MQTTClientDisconnectReason : int8_t { + TCP_DISCONNECTED = 0, + MQTT_UNACCEPTABLE_PROTOCOL_VERSION = 1, + MQTT_IDENTIFIER_REJECTED = 2, + MQTT_SERVER_UNAVAILABLE = 3, + MQTT_MALFORMED_CREDENTIALS = 4, + MQTT_NOT_AUTHORIZED = 5, + ESP8266_NOT_ENOUGH_SPACE = 6, + TLS_BAD_FINGERPRINT = 7 +}; + +/// internal struct for MQTT messages. +struct MQTTMessage { + std::string topic; + std::string payload; + uint8_t qos; ///< QoS. Only for last will testaments. + bool retain; +}; + +class MQTTBackend { + public: + using on_connect_callback_t = void(bool session_present); + using on_disconnect_callback_t = void(MQTTClientDisconnectReason reason); + using on_subscribe_callback_t = void(uint16_t packet_id, uint8_t qos); + using on_unsubscribe_callback_t = void(uint16_t packet_id); + using on_message_callback_t = void(const char *topic, const char *payload, size_t len, size_t index, size_t total); + using on_publish_user_callback_t = void(uint16_t packet_id); + + virtual void set_keep_alive(uint16_t keep_alive) = 0; + virtual void set_client_id(const char *client_id) = 0; + virtual void set_clean_session(bool clean_session) = 0; + virtual void set_credentials(const char *username, const char *password) = 0; + virtual void set_will(const char *topic, uint8_t qos, bool retain, const char *payload) = 0; + virtual void set_server(network::IPAddress ip, uint16_t port) = 0; + virtual void set_server(const char *host, uint16_t port) = 0; + virtual void set_on_connect(std::function &&callback) = 0; + virtual void set_on_disconnect(std::function &&callback) = 0; + virtual void set_on_subscribe(std::function &&callback) = 0; + virtual void set_on_unsubscribe(std::function &&callback) = 0; + virtual void set_on_message(std::function &&callback) = 0; + virtual void set_on_publish(std::function &&callback) = 0; + virtual bool connected() const = 0; + virtual void connect() = 0; + virtual void disconnect() = 0; + virtual bool subscribe(const char *topic, uint8_t qos) = 0; + virtual bool unsubscribe(const char *topic) = 0; + virtual bool publish(const char *topic, const char *payload, size_t length, uint8_t qos, bool retain) = 0; + + virtual bool publish(const MQTTMessage &message) { + return publish(message.topic.c_str(), message.payload.c_str(), message.payload.length(), message.qos, + message.retain); + } + + // called from MQTTClient::loop() + virtual void loop() {} +}; + +} // namespace mqtt +} // namespace esphome diff --git a/esphome/components/mqtt/mqtt_backend_arduino.h b/esphome/components/mqtt/mqtt_backend_arduino.h new file mode 100644 index 0000000000..6399ec88e0 --- /dev/null +++ b/esphome/components/mqtt/mqtt_backend_arduino.h @@ -0,0 +1,74 @@ +#pragma once + +#ifdef USE_ARDUINO + +#include "mqtt_backend.h" +#include + +namespace esphome { +namespace mqtt { + +class MQTTBackendArduino final : public MQTTBackend { + public: + void set_keep_alive(uint16_t keep_alive) final { mqtt_client_.setKeepAlive(keep_alive); } + void set_client_id(const char *client_id) final { mqtt_client_.setClientId(client_id); } + void set_clean_session(bool clean_session) final { mqtt_client_.setCleanSession(clean_session); } + void set_credentials(const char *username, const char *password) final { + mqtt_client_.setCredentials(username, password); + } + void set_will(const char *topic, uint8_t qos, bool retain, const char *payload) final { + mqtt_client_.setWill(topic, qos, retain, payload); + } + void set_server(network::IPAddress ip, uint16_t port) final { + mqtt_client_.setServer(IPAddress(static_cast(ip)), port); + } + void set_server(const char *host, uint16_t port) final { mqtt_client_.setServer(host, port); } +#if ASYNC_TCP_SSL_ENABLED + void set_secure(bool secure) { mqtt_client.setSecure(secure); } + void add_server_fingerprint(const uint8_t *fingerprint) { mqtt_client.addServerFingerprint(fingerprint); } +#endif + + void set_on_connect(std::function &&callback) final { + this->mqtt_client_.onConnect(std::move(callback)); + } + void set_on_disconnect(std::function &&callback) final { + auto async_callback = [callback](AsyncMqttClientDisconnectReason reason) { + // int based enum so casting isn't a problem + callback(static_cast(reason)); + }; + this->mqtt_client_.onDisconnect(std::move(async_callback)); + } + void set_on_subscribe(std::function &&callback) final { + this->mqtt_client_.onSubscribe(std::move(callback)); + } + void set_on_unsubscribe(std::function &&callback) final { + this->mqtt_client_.onUnsubscribe(std::move(callback)); + } + void set_on_message(std::function &&callback) final { + auto async_callback = [callback](const char *topic, const char *payload, + AsyncMqttClientMessageProperties async_properties, size_t len, size_t index, + size_t total) { callback(topic, payload, len, index, total); }; + mqtt_client_.onMessage(std::move(async_callback)); + } + void set_on_publish(std::function &&callback) final { + this->mqtt_client_.onPublish(std::move(callback)); + } + + bool connected() const final { return mqtt_client_.connected(); } + void connect() final { mqtt_client_.connect(); } + void disconnect() final { mqtt_client_.disconnect(true); } + bool subscribe(const char *topic, uint8_t qos) final { return mqtt_client_.subscribe(topic, qos) != 0; } + bool unsubscribe(const char *topic) final { return mqtt_client_.unsubscribe(topic) != 0; } + bool publish(const char *topic, const char *payload, size_t length, uint8_t qos, bool retain) final { + return mqtt_client_.publish(topic, qos, retain, payload, length, false, 0) != 0; + } + using MQTTBackend::publish; + + protected: + AsyncMqttClient mqtt_client_; +}; + +} // namespace mqtt +} // namespace esphome + +#endif // defined(USE_ARDUINO) diff --git a/esphome/components/mqtt/mqtt_backend_idf.cpp b/esphome/components/mqtt/mqtt_backend_idf.cpp new file mode 100644 index 0000000000..0726f72567 --- /dev/null +++ b/esphome/components/mqtt/mqtt_backend_idf.cpp @@ -0,0 +1,149 @@ +#ifdef USE_ESP_IDF + +#include +#include "mqtt_backend_idf.h" +#include "esphome/core/log.h" +#include "esphome/core/helpers.h" + +namespace esphome { +namespace mqtt { + +static const char *const TAG = "mqtt.idf"; + +bool MQTTBackendIDF::initialize_() { + mqtt_cfg_.user_context = (void *) this; + mqtt_cfg_.buffer_size = MQTT_BUFFER_SIZE; + + mqtt_cfg_.host = this->host_.c_str(); + mqtt_cfg_.port = this->port_; + mqtt_cfg_.keepalive = this->keep_alive_; + mqtt_cfg_.disable_clean_session = !this->clean_session_; + + if (!this->username_.empty()) { + mqtt_cfg_.username = this->username_.c_str(); + if (!this->password_.empty()) { + mqtt_cfg_.password = this->password_.c_str(); + } + } + + if (!this->lwt_topic_.empty()) { + mqtt_cfg_.lwt_topic = this->lwt_topic_.c_str(); + this->mqtt_cfg_.lwt_qos = this->lwt_qos_; + this->mqtt_cfg_.lwt_retain = this->lwt_retain_; + + if (!this->lwt_message_.empty()) { + mqtt_cfg_.lwt_msg = this->lwt_message_.c_str(); + mqtt_cfg_.lwt_msg_len = this->lwt_message_.size(); + } + } + + if (!this->client_id_.empty()) { + mqtt_cfg_.client_id = this->client_id_.c_str(); + } + if (ca_certificate_.has_value()) { + mqtt_cfg_.cert_pem = ca_certificate_.value().c_str(); + mqtt_cfg_.skip_cert_common_name_check = skip_cert_cn_check_; + mqtt_cfg_.transport = MQTT_TRANSPORT_OVER_SSL; + } else { + mqtt_cfg_.transport = MQTT_TRANSPORT_OVER_TCP; + } + auto *mqtt_client = esp_mqtt_client_init(&mqtt_cfg_); + if (mqtt_client) { + handler_.reset(mqtt_client); + is_initalized_ = true; + esp_mqtt_client_register_event(mqtt_client, MQTT_EVENT_ANY, mqtt_event_handler, this); + return true; + } else { + ESP_LOGE(TAG, "Failed to initialize IDF-MQTT"); + return false; + } +} + +void MQTTBackendIDF::loop() { + // process new events + // handle only 1 message per loop iteration + if (!mqtt_events_.empty()) { + auto &event = mqtt_events_.front(); + mqtt_event_handler_(event); + mqtt_events_.pop(); + } +} + +void MQTTBackendIDF::mqtt_event_handler_(const esp_mqtt_event_t &event) { + ESP_LOGV(TAG, "Event dispatched from event loop event_id=%d", event.event_id); + switch (event.event_id) { + case MQTT_EVENT_BEFORE_CONNECT: + ESP_LOGV(TAG, "MQTT_EVENT_BEFORE_CONNECT"); + break; + + case MQTT_EVENT_CONNECTED: + ESP_LOGV(TAG, "MQTT_EVENT_CONNECTED"); + // TODO session present check + this->is_connected_ = true; + this->on_connect_.call(!mqtt_cfg_.disable_clean_session); + break; + case MQTT_EVENT_DISCONNECTED: + ESP_LOGV(TAG, "MQTT_EVENT_DISCONNECTED"); + // TODO is there a way to get the disconnect reason? + this->is_connected_ = false; + this->on_disconnect_.call(MQTTClientDisconnectReason::TCP_DISCONNECTED); + break; + + case MQTT_EVENT_SUBSCRIBED: + ESP_LOGV(TAG, "MQTT_EVENT_SUBSCRIBED, msg_id=%d", event.msg_id); + // hardcode QoS to 0. QoS is not used in this context but required to mirror the AsyncMqtt interface + this->on_subscribe_.call((int) event.msg_id, 0); + break; + case MQTT_EVENT_UNSUBSCRIBED: + ESP_LOGV(TAG, "MQTT_EVENT_UNSUBSCRIBED, msg_id=%d", event.msg_id); + this->on_unsubscribe_.call((int) event.msg_id); + break; + case MQTT_EVENT_PUBLISHED: + ESP_LOGV(TAG, "MQTT_EVENT_PUBLISHED, msg_id=%d", event.msg_id); + this->on_publish_.call((int) event.msg_id); + break; + case MQTT_EVENT_DATA: { + static std::string topic; + if (event.topic) { + // not 0 terminated - create a string from it + topic = std::string(event.topic, event.topic_len); + } + ESP_LOGV(TAG, "MQTT_EVENT_DATA %s", topic.c_str()); + auto data_len = event.data_len; + if (data_len == 0) + data_len = strlen(event.data); + this->on_message_.call(event.topic ? const_cast(topic.c_str()) : nullptr, event.data, data_len, + event.current_data_offset, event.total_data_len); + } break; + case MQTT_EVENT_ERROR: + ESP_LOGE(TAG, "MQTT_EVENT_ERROR"); + if (event.error_handle->error_type == MQTT_ERROR_TYPE_TCP_TRANSPORT) { + ESP_LOGE(TAG, "Last error code reported from esp-tls: 0x%x", event.error_handle->esp_tls_last_esp_err); + ESP_LOGE(TAG, "Last tls stack error number: 0x%x", event.error_handle->esp_tls_stack_err); + ESP_LOGE(TAG, "Last captured errno : %d (%s)", event.error_handle->esp_transport_sock_errno, + strerror(event.error_handle->esp_transport_sock_errno)); + } else if (event.error_handle->error_type == MQTT_ERROR_TYPE_CONNECTION_REFUSED) { + ESP_LOGE(TAG, "Connection refused error: 0x%x", event.error_handle->connect_return_code); + } else { + ESP_LOGE(TAG, "Unknown error type: 0x%x", event.error_handle->error_type); + } + break; + default: + ESP_LOGV(TAG, "Other event id:%d", event.event_id); + break; + } +} + +/// static - Dispatch event to instance method +void MQTTBackendIDF::mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_t event_id, void *event_data) { + MQTTBackendIDF *instance = static_cast(handler_args); + // queue event to decouple processing + if (instance) { + auto event = *static_cast(event_data); + instance->mqtt_events_.push(event); + } +} + +} // namespace mqtt +} // namespace esphome +#endif // USE_ESP_IDF diff --git a/esphome/components/mqtt/mqtt_backend_idf.h b/esphome/components/mqtt/mqtt_backend_idf.h new file mode 100644 index 0000000000..77b5592d72 --- /dev/null +++ b/esphome/components/mqtt/mqtt_backend_idf.h @@ -0,0 +1,143 @@ +#pragma once + +#ifdef USE_ESP_IDF + +#include +#include +#include +#include "esphome/components/network/ip_address.h" +#include "esphome/core/helpers.h" +#include "mqtt_backend.h" + +namespace esphome { +namespace mqtt { + +class MQTTBackendIDF final : public MQTTBackend { + public: + static const size_t MQTT_BUFFER_SIZE = 4096; + + void set_keep_alive(uint16_t keep_alive) final { this->keep_alive_ = keep_alive; } + void set_client_id(const char *client_id) final { this->client_id_ = client_id; } + void set_clean_session(bool clean_session) final { this->clean_session_ = clean_session; } + + void set_credentials(const char *username, const char *password) final { + if (username) + this->username_ = username; + if (password) + this->password_ = password; + } + void set_will(const char *topic, uint8_t qos, bool retain, const char *payload) final { + if (topic) + this->lwt_topic_ = topic; + this->lwt_qos_ = qos; + if (payload) + this->lwt_message_ = payload; + this->lwt_retain_ = retain; + } + void set_server(network::IPAddress ip, uint16_t port) final { + this->host_ = ip.str(); + this->port_ = port; + } + void set_server(const char *host, uint16_t port) final { + this->host_ = host; + this->port_ = port; + } + void set_on_connect(std::function &&callback) final { + this->on_connect_.add(std::move(callback)); + } + void set_on_disconnect(std::function &&callback) final { + this->on_disconnect_.add(std::move(callback)); + } + void set_on_subscribe(std::function &&callback) final { + this->on_subscribe_.add(std::move(callback)); + } + void set_on_unsubscribe(std::function &&callback) final { + this->on_unsubscribe_.add(std::move(callback)); + } + void set_on_message(std::function &&callback) final { + this->on_message_.add(std::move(callback)); + } + void set_on_publish(std::function &&callback) final { + this->on_publish_.add(std::move(callback)); + } + bool connected() const final { return this->is_connected_; } + + void connect() final { + if (!is_initalized_) { + if (initialize_()) { + esp_mqtt_client_start(handler_.get()); + } + } + } + void disconnect() final { + if (is_initalized_) + esp_mqtt_client_disconnect(handler_.get()); + } + + bool subscribe(const char *topic, uint8_t qos) final { + return esp_mqtt_client_subscribe(handler_.get(), topic, qos) != -1; + } + bool unsubscribe(const char *topic) final { return esp_mqtt_client_unsubscribe(handler_.get(), topic) != -1; } + + bool publish(const char *topic, const char *payload, size_t length, uint8_t qos, bool retain) final { +#if defined(USE_MQTT_IDF_ENQUEUE) + // use the non-blocking version + // it can delay sending a couple of seconds but won't block + return esp_mqtt_client_enqueue(handler_.get(), topic, payload, length, qos, retain, true) != -1; +#else + // might block for several seconds, either due to network timeout (10s) + // or if publishing payloads longer than internal buffer (due to message fragmentation) + return esp_mqtt_client_publish(handler_.get(), topic, payload, length, qos, retain) != -1; +#endif + } + using MQTTBackend::publish; + + void loop() final; + + void set_ca_certificate(const std::string &cert) { ca_certificate_ = cert; } + void set_skip_cert_cn_check(bool skip_check) { skip_cert_cn_check_ = skip_check; } + + protected: + bool initialize_(); + void mqtt_event_handler_(const esp_mqtt_event_t &event); + static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_t event_id, void *event_data); + + struct MqttClientDeleter { + void operator()(esp_mqtt_client *client_handler) { esp_mqtt_client_destroy(client_handler); } + }; + using ClientHandler_ = std::unique_ptr; + ClientHandler_ handler_; + + bool is_connected_{false}; + bool is_initalized_{false}; + + esp_mqtt_client_config_t mqtt_cfg_{}; + + std::string host_; + uint16_t port_; + std::string username_; + std::string password_; + std::string lwt_topic_; + std::string lwt_message_; + uint8_t lwt_qos_; + bool lwt_retain_; + std::string client_id_; + uint16_t keep_alive_; + bool clean_session_; + optional ca_certificate_; + bool skip_cert_cn_check_{false}; + + // callbacks + CallbackManager on_connect_; + CallbackManager on_disconnect_; + CallbackManager on_subscribe_; + CallbackManager on_unsubscribe_; + CallbackManager on_message_; + CallbackManager on_publish_; + std::queue mqtt_events_; +}; + +} // namespace mqtt +} // namespace esphome + +#endif diff --git a/esphome/components/mqtt/mqtt_client.cpp b/esphome/components/mqtt/mqtt_client.cpp index 1fea0c80cc..3c6ce7cdfc 100644 --- a/esphome/components/mqtt/mqtt_client.cpp +++ b/esphome/components/mqtt/mqtt_client.cpp @@ -27,21 +27,21 @@ MQTTClientComponent::MQTTClientComponent() { // Connection void MQTTClientComponent::setup() { ESP_LOGCONFIG(TAG, "Setting up MQTT..."); - this->mqtt_client_.onMessage([this](char const *topic, char *payload, AsyncMqttClientMessageProperties properties, - size_t len, size_t index, size_t total) { - if (index == 0) - this->payload_buffer_.reserve(total); + this->mqtt_backend_.set_on_message( + [this](const char *topic, const char *payload, size_t len, size_t index, size_t total) { + if (index == 0) + this->payload_buffer_.reserve(total); - // append new payload, may contain incomplete MQTT message - this->payload_buffer_.append(payload, len); + // append new payload, may contain incomplete MQTT message + this->payload_buffer_.append(payload, len); - // MQTT fully received - if (len + index == total) { - this->on_message(topic, this->payload_buffer_); - this->payload_buffer_.clear(); - } - }); - this->mqtt_client_.onDisconnect([this](AsyncMqttClientDisconnectReason reason) { + // MQTT fully received + if (len + index == total) { + this->on_message(topic, this->payload_buffer_); + this->payload_buffer_.clear(); + } + }); + this->mqtt_backend_.set_on_disconnect([this](MQTTClientDisconnectReason reason) { this->state_ = MQTT_CLIENT_DISCONNECTED; this->disconnect_reason_ = reason; }); @@ -49,8 +49,10 @@ void MQTTClientComponent::setup() { if (this->is_log_message_enabled() && logger::global_logger != nullptr) { logger::global_logger->add_on_log_callback([this](int level, const char *tag, const char *message) { if (level <= this->log_level_ && this->is_connected()) { - this->publish(this->log_message_.topic, message, strlen(message), this->log_message_.qos, - this->log_message_.retain); + this->publish({.topic = this->log_message_.topic, + .payload = message, + .qos = this->log_message_.qos, + .retain = this->log_message_.retain}); } }); } @@ -173,9 +175,9 @@ void MQTTClientComponent::start_connect_() { ESP_LOGI(TAG, "Connecting to MQTT..."); // Force disconnect first - this->mqtt_client_.disconnect(true); + this->mqtt_backend_.disconnect(); - this->mqtt_client_.setClientId(this->credentials_.client_id.c_str()); + this->mqtt_backend_.set_client_id(this->credentials_.client_id.c_str()); const char *username = nullptr; if (!this->credentials_.username.empty()) username = this->credentials_.username.c_str(); @@ -183,24 +185,24 @@ void MQTTClientComponent::start_connect_() { if (!this->credentials_.password.empty()) password = this->credentials_.password.c_str(); - this->mqtt_client_.setCredentials(username, password); + this->mqtt_backend_.set_credentials(username, password); - this->mqtt_client_.setServer((uint32_t) this->ip_, this->credentials_.port); + this->mqtt_backend_.set_server((uint32_t) this->ip_, this->credentials_.port); if (!this->last_will_.topic.empty()) { - this->mqtt_client_.setWill(this->last_will_.topic.c_str(), this->last_will_.qos, this->last_will_.retain, - this->last_will_.payload.c_str(), this->last_will_.payload.length()); + this->mqtt_backend_.set_will(this->last_will_.topic.c_str(), this->last_will_.qos, this->last_will_.retain, + this->last_will_.payload.c_str()); } - this->mqtt_client_.connect(); + this->mqtt_backend_.connect(); this->state_ = MQTT_CLIENT_CONNECTING; this->connect_begin_ = millis(); } bool MQTTClientComponent::is_connected() { - return this->state_ == MQTT_CLIENT_CONNECTED && this->mqtt_client_.connected(); + return this->state_ == MQTT_CLIENT_CONNECTED && this->mqtt_backend_.connected(); } void MQTTClientComponent::check_connected() { - if (!this->mqtt_client_.connected()) { + if (!this->mqtt_backend_.connected()) { if (millis() - this->connect_begin_ > 60000) { this->state_ = MQTT_CLIENT_DISCONNECTED; this->start_dnslookup_(); @@ -222,31 +224,34 @@ void MQTTClientComponent::check_connected() { } void MQTTClientComponent::loop() { + // Call the backend loop first + mqtt_backend_.loop(); + if (this->disconnect_reason_.has_value()) { const LogString *reason_s; switch (*this->disconnect_reason_) { - case AsyncMqttClientDisconnectReason::TCP_DISCONNECTED: + case MQTTClientDisconnectReason::TCP_DISCONNECTED: reason_s = LOG_STR("TCP disconnected"); break; - case AsyncMqttClientDisconnectReason::MQTT_UNACCEPTABLE_PROTOCOL_VERSION: + case MQTTClientDisconnectReason::MQTT_UNACCEPTABLE_PROTOCOL_VERSION: reason_s = LOG_STR("Unacceptable Protocol Version"); break; - case AsyncMqttClientDisconnectReason::MQTT_IDENTIFIER_REJECTED: + case MQTTClientDisconnectReason::MQTT_IDENTIFIER_REJECTED: reason_s = LOG_STR("Identifier Rejected"); break; - case AsyncMqttClientDisconnectReason::MQTT_SERVER_UNAVAILABLE: + case MQTTClientDisconnectReason::MQTT_SERVER_UNAVAILABLE: reason_s = LOG_STR("Server Unavailable"); break; - case AsyncMqttClientDisconnectReason::MQTT_MALFORMED_CREDENTIALS: + case MQTTClientDisconnectReason::MQTT_MALFORMED_CREDENTIALS: reason_s = LOG_STR("Malformed Credentials"); break; - case AsyncMqttClientDisconnectReason::MQTT_NOT_AUTHORIZED: + case MQTTClientDisconnectReason::MQTT_NOT_AUTHORIZED: reason_s = LOG_STR("Not Authorized"); break; - case AsyncMqttClientDisconnectReason::ESP8266_NOT_ENOUGH_SPACE: + case MQTTClientDisconnectReason::ESP8266_NOT_ENOUGH_SPACE: reason_s = LOG_STR("Not Enough Space"); break; - case AsyncMqttClientDisconnectReason::TLS_BAD_FINGERPRINT: + case MQTTClientDisconnectReason::TLS_BAD_FINGERPRINT: reason_s = LOG_STR("TLS Bad Fingerprint"); break; default: @@ -275,7 +280,7 @@ void MQTTClientComponent::loop() { this->check_connected(); break; case MQTT_CLIENT_CONNECTED: - if (!this->mqtt_client_.connected()) { + if (!this->mqtt_backend_.connected()) { this->state_ = MQTT_CLIENT_DISCONNECTED; ESP_LOGW(TAG, "Lost MQTT Client connection!"); this->start_dnslookup_(); @@ -302,10 +307,10 @@ bool MQTTClientComponent::subscribe_(const char *topic, uint8_t qos) { if (!this->is_connected()) return false; - uint16_t ret = this->mqtt_client_.subscribe(topic, qos); + bool ret = this->mqtt_backend_.subscribe(topic, qos); yield(); - if (ret != 0) { + if (ret) { ESP_LOGV(TAG, "subscribe(topic='%s')", topic); } else { delay(5); @@ -360,9 +365,9 @@ void MQTTClientComponent::subscribe_json(const std::string &topic, const mqtt_js } void MQTTClientComponent::unsubscribe(const std::string &topic) { - uint16_t ret = this->mqtt_client_.unsubscribe(topic.c_str()); + bool ret = this->mqtt_backend_.unsubscribe(topic.c_str()); yield(); - if (ret != 0) { + if (ret) { ESP_LOGV(TAG, "unsubscribe(topic='%s')", topic.c_str()); } else { delay(5); @@ -387,34 +392,35 @@ bool MQTTClientComponent::publish(const std::string &topic, const std::string &p bool MQTTClientComponent::publish(const std::string &topic, const char *payload, size_t payload_length, uint8_t qos, bool retain) { + return publish({.topic = topic, .payload = payload, .qos = qos, .retain = retain}); +} + +bool MQTTClientComponent::publish(const MQTTMessage &message) { if (!this->is_connected()) { // critical components will re-transmit their messages return false; } - bool logging_topic = topic == this->log_message_.topic; - uint16_t ret = this->mqtt_client_.publish(topic.c_str(), qos, retain, payload, payload_length); + bool logging_topic = this->log_message_.topic == message.topic; + bool ret = this->mqtt_backend_.publish(message); delay(0); - if (ret == 0 && !logging_topic && this->is_connected()) { + if (!ret && !logging_topic && this->is_connected()) { delay(0); - ret = this->mqtt_client_.publish(topic.c_str(), qos, retain, payload, payload_length); + ret = this->mqtt_backend_.publish(message); delay(0); } if (!logging_topic) { - if (ret != 0) { - ESP_LOGV(TAG, "Publish(topic='%s' payload='%s' retain=%d)", topic.c_str(), payload, retain); + if (ret) { + ESP_LOGV(TAG, "Publish(topic='%s' payload='%s' retain=%d)", message.topic.c_str(), message.payload.c_str(), + message.retain); } else { - ESP_LOGV(TAG, "Publish failed for topic='%s' (len=%u). will retry later..", topic.c_str(), - payload_length); // NOLINT + ESP_LOGV(TAG, "Publish failed for topic='%s' (len=%u). will retry later..", message.topic.c_str(), + message.payload.length()); this->status_momentary_warning("publish", 1000); } } return ret != 0; } - -bool MQTTClientComponent::publish(const MQTTMessage &message) { - return this->publish(message.topic, message.payload, message.qos, message.retain); -} bool MQTTClientComponent::publish_json(const std::string &topic, const json::json_build_t &f, uint8_t qos, bool retain) { std::string message = json::build_json(f); @@ -499,10 +505,10 @@ bool MQTTClientComponent::is_log_message_enabled() const { return !this->log_mes void MQTTClientComponent::set_reboot_timeout(uint32_t reboot_timeout) { this->reboot_timeout_ = reboot_timeout; } void MQTTClientComponent::register_mqtt_component(MQTTComponent *component) { this->children_.push_back(component); } void MQTTClientComponent::set_log_level(int level) { this->log_level_ = level; } -void MQTTClientComponent::set_keep_alive(uint16_t keep_alive_s) { this->mqtt_client_.setKeepAlive(keep_alive_s); } +void MQTTClientComponent::set_keep_alive(uint16_t keep_alive_s) { this->mqtt_backend_.set_keep_alive(keep_alive_s); } void MQTTClientComponent::set_log_message_template(MQTTMessage &&message) { this->log_message_ = std::move(message); } const MQTTDiscoveryInfo &MQTTClientComponent::get_discovery_info() const { return this->discovery_info_; } -void MQTTClientComponent::set_topic_prefix(std::string topic_prefix) { this->topic_prefix_ = std::move(topic_prefix); } +void MQTTClientComponent::set_topic_prefix(const std::string &topic_prefix) { this->topic_prefix_ = topic_prefix; } const std::string &MQTTClientComponent::get_topic_prefix() const { return this->topic_prefix_; } void MQTTClientComponent::disable_birth_message() { this->birth_message_.topic = ""; @@ -549,7 +555,8 @@ void MQTTClientComponent::set_discovery_info(std::string &&prefix, MQTTDiscovery void MQTTClientComponent::disable_last_will() { this->last_will_.topic = ""; } void MQTTClientComponent::disable_discovery() { - this->discovery_info_ = MQTTDiscoveryInfo{.prefix = "", .retain = false}; + this->discovery_info_ = MQTTDiscoveryInfo{ + .prefix = "", .retain = false, .clean = false, .unique_id_generator = MQTT_LEGACY_UNIQUE_ID_GENERATOR}; } void MQTTClientComponent::on_shutdown() { if (!this->shutdown_message_.topic.empty()) { @@ -557,13 +564,13 @@ void MQTTClientComponent::on_shutdown() { this->publish(this->shutdown_message_); yield(); } - this->mqtt_client_.disconnect(true); + this->mqtt_backend_.disconnect(); } #if ASYNC_TCP_SSL_ENABLED void MQTTClientComponent::add_ssl_fingerprint(const std::array &fingerprint) { - this->mqtt_client_.setSecure(true); - this->mqtt_client_.addServerFingerprint(fingerprint.data()); + this->mqtt_backend_.setSecure(true); + this->mqtt_backend_.addServerFingerprint(fingerprint.data()); } #endif diff --git a/esphome/components/mqtt/mqtt_client.h b/esphome/components/mqtt/mqtt_client.h index 58a4fbe166..4880bbaa5b 100644 --- a/esphome/components/mqtt/mqtt_client.h +++ b/esphome/components/mqtt/mqtt_client.h @@ -9,7 +9,11 @@ #include "esphome/core/log.h" #include "esphome/components/json/json_util.h" #include "esphome/components/network/ip_address.h" -#include +#if defined(USE_ESP_IDF) +#include "mqtt_backend_idf.h" +#elif defined(USE_ARDUINO) +#include "mqtt_backend_arduino.h" +#endif #include "lwip/ip_addr.h" namespace esphome { @@ -22,14 +26,6 @@ namespace mqtt { using mqtt_callback_t = std::function; using mqtt_json_callback_t = std::function; -/// internal struct for MQTT messages. -struct MQTTMessage { - std::string topic; - std::string payload; - uint8_t qos; ///< QoS. Only for last will testaments. - bool retain; -}; - /// internal struct for MQTT subscriptions. struct MQTTSubscription { std::string topic; @@ -139,7 +135,10 @@ class MQTTClientComponent : public Component { */ void add_ssl_fingerprint(const std::array &fingerprint); #endif - +#ifdef USE_ESP_IDF + void set_ca_certificate(const char *cert) { this->mqtt_backend_.set_ca_certificate(cert); } + void set_skip_cert_cn_check(bool skip_check) { this->mqtt_backend_.set_skip_cert_cn_check(skip_check); } +#endif const Availability &get_availability(); /** Set the topic prefix that will be prepended to all topics together with "/". This will, in most cases, @@ -150,7 +149,7 @@ class MQTTClientComponent : public Component { * * @param topic_prefix The topic prefix. The last "/" is appended automatically. */ - void set_topic_prefix(std::string topic_prefix); + void set_topic_prefix(const std::string &topic_prefix); /// Get the topic prefix of this device, using default if necessary const std::string &get_topic_prefix() const; @@ -277,6 +276,7 @@ class MQTTClientComponent : public Component { .prefix = "homeassistant", .retain = true, .clean = false, + .unique_id_generator = MQTT_LEGACY_UNIQUE_ID_GENERATOR, }; std::string topic_prefix_{}; MQTTMessage log_message_; @@ -284,7 +284,12 @@ class MQTTClientComponent : public Component { int log_level_{ESPHOME_LOG_LEVEL}; std::vector subscriptions_; - AsyncMqttClient mqtt_client_; +#if defined(USE_ESP_IDF) + MQTTBackendIDF mqtt_backend_; +#elif defined(USE_ARDUINO) + MQTTBackendArduino mqtt_backend_; +#endif + MQTTClientState state_{MQTT_CLIENT_DISCONNECTED}; network::IPAddress ip_; bool dns_resolved_{false}; @@ -293,7 +298,7 @@ class MQTTClientComponent : public Component { uint32_t reboot_timeout_{300000}; uint32_t connect_begin_; uint32_t last_connected_{0}; - optional disconnect_reason_{}; + optional disconnect_reason_{}; }; extern MQTTClientComponent *global_mqtt_client; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/esphome/core/defines.h b/esphome/core/defines.h index b5c82338b3..f304f847a5 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -29,6 +29,7 @@ #define USE_LOCK #define USE_LOGGER #define USE_MDNS +#define USE_MQTT #define USE_NUMBER #define USE_OTA_PASSWORD #define USE_OTA_STATE_CALLBACK @@ -49,13 +50,17 @@ #define USE_CAPTIVE_PORTAL #define USE_JSON #define USE_NEXTION_TFT_UPLOAD -#define USE_MQTT #define USE_PROMETHEUS #define USE_WEBSERVER #define USE_WEBSERVER_PORT 80 // NOLINT #define USE_WIFI_WPA2_EAP #endif +// IDF-specific feature flags +#ifdef USE_ESP_IDF +#define USE_MQTT_IDF_ENQUEUE +#endif + // ESP32-specific feature flags #ifdef USE_ESP32 #define USE_ESP32_BLE_CLIENT diff --git a/tests/test5.yaml b/tests/test5.yaml index 9bfd395538..ee90cc1149 100644 --- a/tests/test5.yaml +++ b/tests/test5.yaml @@ -49,6 +49,19 @@ modbus_controller: address: 0x2 modbus_id: mod_bus1 +mqtt: + broker: test.mosquitto.org + port: 1883 + discovery: true + discovery_prefix: homeassistant + idf_send_async: false + on_message: + topic: testing/sensor/testing_sensor/state + qos: 0 + then: + - lambda: |- + ESP_LOGD("Mqtt Test","testing/sensor/testing_sensor/state=[%s]",x.c_str()); + binary_sensor: - platform: gpio pin: GPIO0