diff --git a/esphome/config_validation.py b/esphome/config_validation.py index 1ad955ffe0..bc3772dad6 100644 --- a/esphome/config_validation.py +++ b/esphome/config_validation.py @@ -1634,11 +1634,10 @@ class GenerateID(Optional): super().__init__(key, default=lambda: None) -def _get_priority_default(*args): - for arg in args: - if arg is not vol.UNDEFINED: - return arg - return vol.UNDEFINED +def _get_platform_key(*args): + key = [CORE.target_platform] + key.extend(args) + return ["_".join(key)] class SplitDefault(Optional): @@ -1648,38 +1647,30 @@ class SplitDefault(Optional): super().__init__(key) self._defaults = {} - priority_mapping = { - "esp32_arduino": ["esp32"], - "esp32_idf": ["esp32"], - "esp32_s2_arduino": ["esp32_s2", "esp32_arduino", "esp32"], - "esp32_s2_idf": ["esp32_s2", "esp32_idf", "esp32"], - "esp32_s3_arduino": ["esp32_s3", "esp32_arduino", "esp32"], - "esp32_s3_idf": ["esp32_s3", "esp32_idf", "esp32"], - "esp32_c3_arduino": ["esp32_c3", "esp32_arduino", "esp32"], - "esp32_c3_idf": ["esp32_c3", "esp32_idf", "esp32"], - } - for platform_key, value in kwargs.items(): - if platform_key not in priority_mapping: - self._defaults[platform_key] = vol.default_factory(value) - for platform_key, priority in priority_mapping.items(): - prioritized_default = _get_priority_default( - *[kwargs.get(p, vol.UNDEFINED) for p in [platform_key] + priority] - ) - self._defaults[platform_key] = vol.default_factory(prioritized_default) + for platform_key, value in kwargs.items(): + self._defaults[platform_key] = vol.default_factory(value) + self._defaults[vol.UNDEFINED] = vol.default_factory(vol.UNDEFINED) @property def default(self): - key = [CORE.target_platform] + keys = [] if CORE.is_esp32: from esphome.components.esp32 import get_esp32_variant from esphome.components.esp32.const import VARIANT_ESP32 variant = get_esp32_variant().replace(VARIANT_ESP32, "").lower() + framework = CORE.target_framework.replace("esp-", "") if variant: - key += [variant] - key += [CORE.target_framework.replace("esp-", "")] - return self._defaults.get("_".join(key), vol.default_factory(vol.UNDEFINED)) + keys += _get_platform_key(variant, framework) + keys += _get_platform_key(variant) + keys += _get_platform_key(framework) + keys += _get_platform_key() + keys += [vol.UNDEFINED] + for key in keys: + if self._defaults.get(key) is not None: + return self._defaults[key] + raise NotImplementedError @default.setter def default(self, value):