Add ability to await safe mode in codegen (#4529)

* Add ability to await OTA safe mode

* Make pylint happy
This commit is contained in:
Oxan van Leeuwen 2023-03-07 22:29:45 +01:00 committed by GitHub
parent b29cc58144
commit ceebe14628
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 0 deletions

View file

@ -47,6 +47,7 @@ from esphome.cpp_helpers import ( # noqa
build_registry_list, build_registry_list,
extract_registry_entry_config, extract_registry_entry_config,
register_parented, register_parented,
past_safe_mode,
) )
from esphome.cpp_types import ( # noqa from esphome.cpp_types import ( # noqa
global_ns, global_ns,

View file

@ -10,6 +10,8 @@ from esphome.const import (
CONF_REBOOT_TIMEOUT, CONF_REBOOT_TIMEOUT,
CONF_SAFE_MODE, CONF_SAFE_MODE,
CONF_TRIGGER_ID, CONF_TRIGGER_ID,
CONF_OTA,
KEY_PAST_SAFE_MODE,
) )
from esphome.core import CORE, coroutine_with_priority from esphome.core import CORE, coroutine_with_priority
@ -76,6 +78,8 @@ CONFIG_SCHEMA = cv.Schema(
@coroutine_with_priority(50.0) @coroutine_with_priority(50.0)
async def to_code(config): async def to_code(config):
CORE.data[CONF_OTA] = {}
var = cg.new_Pvariable(config[CONF_ID]) var = cg.new_Pvariable(config[CONF_ID])
cg.add(var.set_port(config[CONF_PORT])) cg.add(var.set_port(config[CONF_PORT]))
cg.add_define("USE_OTA") cg.add_define("USE_OTA")
@ -90,6 +94,7 @@ async def to_code(config):
config[CONF_NUM_ATTEMPTS], config[CONF_REBOOT_TIMEOUT] config[CONF_NUM_ATTEMPTS], config[CONF_REBOOT_TIMEOUT]
) )
cg.add(RawExpression(f"if ({condition}) return")) cg.add(RawExpression(f"if ({condition}) return"))
CORE.data[CONF_OTA][KEY_PAST_SAFE_MODE] = True
if CORE.is_esp32 and CORE.using_arduino: if CORE.is_esp32 and CORE.using_arduino:
cg.add_library("Update", None) cg.add_library("Update", None)

View file

@ -1023,6 +1023,7 @@ KEY_TARGET_FRAMEWORK = "target_framework"
KEY_FRAMEWORK_VERSION = "framework_version" KEY_FRAMEWORK_VERSION = "framework_version"
KEY_NAME = "name" KEY_NAME = "name"
KEY_VARIANT = "variant" KEY_VARIANT = "variant"
KEY_PAST_SAFE_MODE = "past_safe_mode"
# Entity categories # Entity categories
ENTITY_CATEGORY_NONE = "" ENTITY_CATEGORY_NONE = ""

View file

@ -9,9 +9,13 @@ from esphome.const import (
CONF_SETUP_PRIORITY, CONF_SETUP_PRIORITY,
CONF_UPDATE_INTERVAL, CONF_UPDATE_INTERVAL,
CONF_TYPE_ID, CONF_TYPE_ID,
CONF_OTA,
CONF_SAFE_MODE,
KEY_PAST_SAFE_MODE,
) )
from esphome.core import coroutine, ID, CORE from esphome.core import coroutine, ID, CORE
from esphome.coroutine import FakeAwaitable
from esphome.types import ConfigType, ConfigFragmentType from esphome.types import ConfigType, ConfigFragmentType
from esphome.cpp_generator import add, get_variable from esphome.cpp_generator import add, get_variable
from esphome.cpp_types import App from esphome.cpp_types import App
@ -127,3 +131,19 @@ async def build_registry_list(registry, config):
action = await build_registry_entry(registry, conf) action = await build_registry_entry(registry, conf)
actions.append(action) actions.append(action)
return actions return actions
async def past_safe_mode():
safe_mode_enabled = (
CONF_OTA in CORE.config and CORE.config[CONF_OTA][CONF_SAFE_MODE]
)
if not safe_mode_enabled:
return
def _safe_mode_generator():
while True:
if CORE.data.get(CONF_OTA, {}).get(KEY_PAST_SAFE_MODE, False):
return
yield
return await FakeAwaitable(_safe_mode_generator())