From 623570a117857180816e9dbb2e426aadba56fbd1 Mon Sep 17 00:00:00 2001 From: Maurice Makaay Date: Sat, 10 Jul 2021 21:52:19 +0200 Subject: [PATCH] Add state callback to ota component (#1816) Co-authored-by: Maurice Makaay Co-authored-by: Guillermo Ruffino --- esphome/components/ota/__init__.py | 65 ++++++++++++++++++++++ esphome/components/ota/automation.h | 71 ++++++++++++++++++++++++ esphome/components/ota/ota_component.cpp | 20 ++++++- esphome/components/ota/ota_component.h | 11 ++++ tests/test1.yaml | 18 ++++++ 5 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 esphome/components/ota/automation.h diff --git a/esphome/components/ota/__init__.py b/esphome/components/ota/__init__.py index 7ee7ef47ca..75641ad399 100644 --- a/esphome/components/ota/__init__.py +++ b/esphome/components/ota/__init__.py @@ -1,6 +1,7 @@ from esphome.cpp_generator import RawExpression import esphome.codegen as cg import esphome.config_validation as cv +from esphome import automation from esphome.const import ( CONF_ID, CONF_NUM_ATTEMPTS, @@ -8,14 +9,29 @@ from esphome.const import ( CONF_PORT, CONF_REBOOT_TIMEOUT, CONF_SAFE_MODE, + CONF_TRIGGER_ID, ) from esphome.core import CORE, coroutine_with_priority CODEOWNERS = ["@esphome/core"] DEPENDENCIES = ["network"] +CONF_ON_STATE_CHANGE = "on_state_change" +CONF_ON_BEGIN = "on_begin" +CONF_ON_PROGRESS = "on_progress" +CONF_ON_END = "on_end" +CONF_ON_ERROR = "on_error" + ota_ns = cg.esphome_ns.namespace("ota") +OTAState = ota_ns.enum("OTAState") OTAComponent = ota_ns.class_("OTAComponent", cg.Component) +OTAStateChangeTrigger = ota_ns.class_( + "OTAStateChangeTrigger", automation.Trigger.template() +) +OTAStartTrigger = ota_ns.class_("OTAStartTrigger", automation.Trigger.template()) +OTAProgressTrigger = ota_ns.class_("OTAProgressTrigger", automation.Trigger.template()) +OTAEndTrigger = ota_ns.class_("OTAEndTrigger", automation.Trigger.template()) +OTAErrorTrigger = ota_ns.class_("OTAErrorTrigger", automation.Trigger.template()) CONFIG_SCHEMA = cv.Schema( { @@ -27,6 +43,31 @@ CONFIG_SCHEMA = cv.Schema( CONF_REBOOT_TIMEOUT, default="5min" ): cv.positive_time_period_milliseconds, cv.Optional(CONF_NUM_ATTEMPTS, default="10"): cv.positive_not_null_int, + cv.Optional(CONF_ON_STATE_CHANGE): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAStateChangeTrigger), + } + ), + cv.Optional(CONF_ON_BEGIN): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAStartTrigger), + } + ), + cv.Optional(CONF_ON_ERROR): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAErrorTrigger), + } + ), + cv.Optional(CONF_ON_PROGRESS): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAProgressTrigger), + } + ), + cv.Optional(CONF_ON_END): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAEndTrigger), + } + ), } ).extend(cv.COMPONENT_SCHEMA) @@ -49,3 +90,27 @@ async def to_code(config): cg.add_library("Update", None) elif CORE.is_esp32: cg.add_library("Hash", None) + + use_state_callback = False + for conf in config.get(CONF_ON_STATE_CHANGE, []): + trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) + await automation.build_automation(trigger, [(OTAState, "state")], conf) + use_state_callback = True + for conf in config.get(CONF_ON_BEGIN, []): + trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) + await automation.build_automation(trigger, [], conf) + use_state_callback = True + for conf in config.get(CONF_ON_PROGRESS, []): + trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) + await automation.build_automation(trigger, [(float, "x")], conf) + use_state_callback = True + for conf in config.get(CONF_ON_END, []): + trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) + await automation.build_automation(trigger, [], conf) + use_state_callback = True + for conf in config.get(CONF_ON_ERROR, []): + trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) + await automation.build_automation(trigger, [(int, "x")], conf) + use_state_callback = True + if use_state_callback: + cg.add_define("USE_OTA_STATE_CALLBACK") diff --git a/esphome/components/ota/automation.h b/esphome/components/ota/automation.h new file mode 100644 index 0000000000..6c8aca3705 --- /dev/null +++ b/esphome/components/ota/automation.h @@ -0,0 +1,71 @@ +#pragma once + +#include "esphome/core/defines.h" +#ifdef USE_OTA_STATE_CALLBACK + +#include "esphome/core/component.h" +#include "esphome/core/automation.h" +#include "esphome/components/ota/ota_component.h" + +namespace esphome { +namespace ota { + +class OTAStateChangeTrigger : public Trigger { + public: + explicit OTAStateChangeTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (!parent->is_failed()) { + return trigger(state); + } + }); + } +}; + +class OTAStartTrigger : public Trigger<> { + public: + explicit OTAStartTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (state == OTA_STARTED && !parent->is_failed()) { + trigger(); + } + }); + } +}; + +class OTAProgressTrigger : public Trigger { + public: + explicit OTAProgressTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (state == OTA_IN_PROGRESS && !parent->is_failed()) { + trigger(progress); + } + }); + } +}; + +class OTAEndTrigger : public Trigger<> { + public: + explicit OTAEndTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (state == OTA_COMPLETED && !parent->is_failed()) { + trigger(); + } + }); + } +}; + +class OTAErrorTrigger : public Trigger { + public: + explicit OTAErrorTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (state == OTA_ERROR && !parent->is_failed()) { + trigger(error); + } + }); + } +}; + +} // namespace ota +} // namespace esphome + +#endif // USE_OTA_STATE_CALLBACK diff --git a/esphome/components/ota/ota_component.cpp b/esphome/components/ota/ota_component.cpp index 5302b7bc24..71f8101704 100644 --- a/esphome/components/ota/ota_component.cpp +++ b/esphome/components/ota/ota_component.cpp @@ -1,7 +1,6 @@ #include "ota_component.h" #include "esphome/core/log.h" -#include "esphome/core/helpers.h" #include "esphome/core/application.h" #include "esphome/core/util.h" @@ -25,6 +24,7 @@ void OTAComponent::setup() { this->dump_config(); } + void OTAComponent::dump_config() { ESP_LOGCONFIG(TAG, "Over-The-Air Updates:"); ESP_LOGCONFIG(TAG, " Address: %s:%u", network_get_address().c_str(), this->port_); @@ -71,6 +71,9 @@ void OTAComponent::handle_() { ESP_LOGD(TAG, "Starting OTA Update from %s...", this->client_.remoteIP().toString().c_str()); this->status_set_warning(); +#ifdef USE_OTA_STATE_CALLBACK + this->state_callback_.call(OTA_STARTED, 0.0f, 0); +#endif if (!this->wait_receive_(buf, 5)) { ESP_LOGW(TAG, "Reading magic bytes failed!"); @@ -241,6 +244,9 @@ void OTAComponent::handle_() { last_progress = now; float percentage = (total * 100.0f) / ota_size; ESP_LOGD(TAG, "OTA in progress: %0.1f%%", percentage); +#ifdef USE_OTA_STATE_CALLBACK + this->state_callback_.call(OTA_IN_PROGRESS, percentage, 0); +#endif // slow down OTA update to avoid getting killed by task watchdog (task_wdt) delay(10); } @@ -268,6 +274,9 @@ void OTAComponent::handle_() { delay(10); ESP_LOGI(TAG, "OTA update finished!"); this->status_clear_warning(); +#ifdef USE_OTA_STATE_CALLBACK + this->state_callback_.call(OTA_COMPLETED, 100.0f, 0); +#endif delay(100); // NOLINT App.safe_reboot(); @@ -296,6 +305,9 @@ error: #endif this->status_momentary_error("onerror", 5000); +#ifdef USE_OTA_STATE_CALLBACK + this->state_callback_.call(OTA_ERROR, 0.0f, static_cast(error_code)); +#endif #ifdef ARDUINO_ARCH_ESP8266 global_preferences.prevent_write(false); @@ -400,5 +412,11 @@ void OTAComponent::on_safe_shutdown() { this->clean_rtc(); } +#ifdef USE_OTA_STATE_CALLBACK +void OTAComponent::add_on_state_callback(std::function &&callback) { + this->state_callback_.add(std::move(callback)); +} +#endif + } // namespace ota } // namespace esphome diff --git a/esphome/components/ota/ota_component.h b/esphome/components/ota/ota_component.h index f16725e324..8b5830295e 100644 --- a/esphome/components/ota/ota_component.h +++ b/esphome/components/ota/ota_component.h @@ -2,6 +2,7 @@ #include "esphome/core/component.h" #include "esphome/core/preferences.h" +#include "esphome/core/helpers.h" #include #include @@ -32,6 +33,8 @@ enum OTAResponseTypes { OTA_RESPONSE_ERROR_UNKNOWN = 255, }; +enum OTAState { OTA_COMPLETED = 0, OTA_STARTED, OTA_IN_PROGRESS, OTA_ERROR }; + /// OTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA. class OTAComponent : public Component { public: @@ -49,6 +52,10 @@ class OTAComponent : public Component { bool should_enter_safe_mode(uint8_t num_attempts, uint32_t enable_time); +#ifdef USE_OTA_STATE_CALLBACK + void add_on_state_callback(std::function &&callback); +#endif + // ========== INTERNAL METHODS ========== // (In most use cases you won't need these) void setup() override; @@ -82,6 +89,10 @@ class OTAComponent : public Component { uint32_t safe_mode_rtc_value_; uint8_t safe_mode_num_attempts_; ESPPreferenceObject rtc_; + +#ifdef USE_OTA_STATE_CALLBACK + CallbackManager state_callback_{}; +#endif }; } // namespace ota diff --git a/tests/test1.yaml b/tests/test1.yaml index 1f817f0dab..1c522a23d4 100644 --- a/tests/test1.yaml +++ b/tests/test1.yaml @@ -197,6 +197,24 @@ ota: port: 3286 reboot_timeout: 2min num_attempts: 5 + on_state_change: + then: + lambda: >- + ESP_LOGD("ota", "State %d", state); + on_begin: + then: + logger.log: "OTA begin" + on_progress: + then: + lambda: >- + ESP_LOGD("ota", "Got progress %f", x); + on_end: + then: + logger.log: "OTA end" + on_error: + then: + lambda: >- + ESP_LOGD("ota", "Got error code %d", x); logger: baud_rate: 0