diff --git a/esphome/codegen.py b/esphome/codegen.py index c926c94070..43b44256e2 100644 --- a/esphome/codegen.py +++ b/esphome/codegen.py @@ -47,6 +47,7 @@ from esphome.cpp_helpers import ( # noqa build_registry_list, extract_registry_entry_config, register_parented, + past_safe_mode, ) from esphome.cpp_types import ( # noqa global_ns, diff --git a/esphome/components/ota/__init__.py b/esphome/components/ota/__init__.py index 32ea1fd363..a966157ffa 100644 --- a/esphome/components/ota/__init__.py +++ b/esphome/components/ota/__init__.py @@ -10,6 +10,8 @@ from esphome.const import ( CONF_REBOOT_TIMEOUT, CONF_SAFE_MODE, CONF_TRIGGER_ID, + CONF_OTA, + KEY_PAST_SAFE_MODE, ) from esphome.core import CORE, coroutine_with_priority @@ -76,6 +78,8 @@ CONFIG_SCHEMA = cv.Schema( @coroutine_with_priority(50.0) async def to_code(config): + CORE.data[CONF_OTA] = {} + var = cg.new_Pvariable(config[CONF_ID]) cg.add(var.set_port(config[CONF_PORT])) cg.add_define("USE_OTA") @@ -90,6 +94,7 @@ async def to_code(config): config[CONF_NUM_ATTEMPTS], config[CONF_REBOOT_TIMEOUT] ) cg.add(RawExpression(f"if ({condition}) return")) + CORE.data[CONF_OTA][KEY_PAST_SAFE_MODE] = True if CORE.is_esp32 and CORE.using_arduino: cg.add_library("Update", None) diff --git a/esphome/const.py b/esphome/const.py index 289a59b424..4d6c7623bc 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -1023,6 +1023,7 @@ KEY_TARGET_FRAMEWORK = "target_framework" KEY_FRAMEWORK_VERSION = "framework_version" KEY_NAME = "name" KEY_VARIANT = "variant" +KEY_PAST_SAFE_MODE = "past_safe_mode" # Entity categories ENTITY_CATEGORY_NONE = "" diff --git a/esphome/cpp_helpers.py b/esphome/cpp_helpers.py index 02d339441f..ab5231e055 100644 --- a/esphome/cpp_helpers.py +++ b/esphome/cpp_helpers.py @@ -9,9 +9,13 @@ from esphome.const import ( CONF_SETUP_PRIORITY, CONF_UPDATE_INTERVAL, CONF_TYPE_ID, + CONF_OTA, + CONF_SAFE_MODE, + KEY_PAST_SAFE_MODE, ) from esphome.core import coroutine, ID, CORE +from esphome.coroutine import FakeAwaitable from esphome.types import ConfigType, ConfigFragmentType from esphome.cpp_generator import add, get_variable from esphome.cpp_types import App @@ -127,3 +131,19 @@ async def build_registry_list(registry, config): action = await build_registry_entry(registry, conf) actions.append(action) return actions + + +async def past_safe_mode(): + safe_mode_enabled = ( + CONF_OTA in CORE.config and CORE.config[CONF_OTA][CONF_SAFE_MODE] + ) + if not safe_mode_enabled: + return + + def _safe_mode_generator(): + while True: + if CORE.data.get(CONF_OTA, {}).get(KEY_PAST_SAFE_MODE, False): + return + yield + + return await FakeAwaitable(_safe_mode_generator())