diff --git a/esphome/components/mqtt/__init__.py b/esphome/components/mqtt/__init__.py index b2548d6081..3f9275e231 100644 --- a/esphome/components/mqtt/__init__.py +++ b/esphome/components/mqtt/__init__.py @@ -24,6 +24,8 @@ from esphome.const import ( CONF_LOG_TOPIC, CONF_ON_JSON_MESSAGE, CONF_ON_MESSAGE, + CONF_ON_CONNECT, + CONF_ON_DISCONNECT, CONF_PASSWORD, CONF_PAYLOAD, CONF_PAYLOAD_AVAILABLE, @@ -90,6 +92,10 @@ MQTTMessageTrigger = mqtt_ns.class_( MQTTJsonMessageTrigger = mqtt_ns.class_( "MQTTJsonMessageTrigger", automation.Trigger.template(cg.JsonObjectConst) ) +MQTTConnectTrigger = mqtt_ns.class_("MQTTConnectTrigger", automation.Trigger.template()) +MQTTDisconnectTrigger = mqtt_ns.class_( + "MQTTDisconnectTrigger", automation.Trigger.template() +) MQTTComponent = mqtt_ns.class_("MQTTComponent", cg.Component) MQTTConnectedCondition = mqtt_ns.class_("MQTTConnectedCondition", Condition) @@ -212,6 +218,18 @@ CONFIG_SCHEMA = cv.All( cv.Optional( CONF_REBOOT_TIMEOUT, default="15min" ): cv.positive_time_period_milliseconds, + cv.Optional(CONF_ON_CONNECT): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(MQTTConnectTrigger), + } + ), + cv.Optional(CONF_ON_DISCONNECT): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id( + MQTTDisconnectTrigger + ), + } + ), cv.Optional(CONF_ON_MESSAGE): automation.validate_automation( { cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(MQTTMessageTrigger), @@ -362,6 +380,14 @@ async def to_code(config): trig = cg.new_Pvariable(conf[CONF_TRIGGER_ID], conf[CONF_TOPIC], conf[CONF_QOS]) await automation.build_automation(trig, [(cg.JsonObjectConst, "x")], conf) + for conf in config.get(CONF_ON_CONNECT, []): + trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) + await automation.build_automation(trigger, [], conf) + + for conf in config.get(CONF_ON_DISCONNECT, []): + trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) + await automation.build_automation(trigger, [], conf) + MQTT_PUBLISH_ACTION_SCHEMA = cv.Schema( { diff --git a/esphome/components/mqtt/mqtt_client.cpp b/esphome/components/mqtt/mqtt_client.cpp index 12a43dc232..157504fb41 100644 --- a/esphome/components/mqtt/mqtt_client.cpp +++ b/esphome/components/mqtt/mqtt_client.cpp @@ -572,6 +572,14 @@ void MQTTClientComponent::on_shutdown() { this->mqtt_backend_.disconnect(); } +void MQTTClientComponent::set_on_connect(mqtt_on_connect_callback_t &&callback) { + this->mqtt_backend_.set_on_connect(std::forward(callback)); +} + +void MQTTClientComponent::set_on_disconnect(mqtt_on_disconnect_callback_t &&callback) { + this->mqtt_backend_.set_on_disconnect(std::forward(callback)); +} + #if ASYNC_TCP_SSL_ENABLED void MQTTClientComponent::add_ssl_fingerprint(const std::array &fingerprint) { this->mqtt_backend_.setSecure(true); diff --git a/esphome/components/mqtt/mqtt_client.h b/esphome/components/mqtt/mqtt_client.h index 20b174a66f..0310655146 100644 --- a/esphome/components/mqtt/mqtt_client.h +++ b/esphome/components/mqtt/mqtt_client.h @@ -19,6 +19,11 @@ namespace esphome { namespace mqtt { +/** Callback for MQTT events. + */ +using mqtt_on_connect_callback_t = std::function; +using mqtt_on_disconnect_callback_t = std::function; + /** Callback for MQTT subscriptions. * * First parameter is the topic, the second one is the payload. @@ -240,6 +245,8 @@ class MQTTClientComponent : public Component { void set_username(const std::string &username) { this->credentials_.username = username; } void set_password(const std::string &password) { this->credentials_.password = password; } void set_client_id(const std::string &client_id) { this->credentials_.client_id = client_id; } + void set_on_connect(mqtt_on_connect_callback_t &&callback); + void set_on_disconnect(mqtt_on_disconnect_callback_t &&callback); protected: /// Reconnect to the MQTT broker if not already connected. @@ -328,6 +335,20 @@ class MQTTJsonMessageTrigger : public Trigger { } }; +class MQTTConnectTrigger : public Trigger<> { + public: + explicit MQTTConnectTrigger(MQTTClientComponent *&client) { + client->set_on_connect([this](bool session_present) { this->trigger(); }); + } +}; + +class MQTTDisconnectTrigger : public Trigger<> { + public: + explicit MQTTDisconnectTrigger(MQTTClientComponent *&client) { + client->set_on_disconnect([this](MQTTClientDisconnectReason reason) { this->trigger(); }); + } +}; + template class MQTTPublishAction : public Action { public: MQTTPublishAction(MQTTClientComponent *parent) : parent_(parent) {} diff --git a/tests/test1.yaml b/tests/test1.yaml index 1b8ed7e370..2a157e3513 100644 --- a/tests/test1.yaml +++ b/tests/test1.yaml @@ -168,6 +168,13 @@ mqtt: id: uart0 data: !lambda |- return {}; + on_connect: + - light.turn_on: ${roomname}_lights + - mqtt.publish: + topic: some/topic + payload: Hello + on_disconnect: + - light.turn_off: ${roomname}_lights i2c: sda: 21