[core] Fix some extends cases (#6748)

This commit is contained in:
Jesse Hills 2024-05-16 14:11:54 +12:00 committed by GitHub
parent 247b2eee30
commit 7c243dafb3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 6 deletions

View file

@ -4,7 +4,7 @@ import esphome.config_validation as cv
from esphome import core from esphome import core
from esphome.const import CONF_SUBSTITUTIONS, VALID_SUBSTITUTIONS_CHARACTERS from esphome.const import CONF_SUBSTITUTIONS, VALID_SUBSTITUTIONS_CHARACTERS
from esphome.yaml_util import ESPHomeDataBase, make_data_base from esphome.yaml_util import ESPHomeDataBase, make_data_base
from esphome.config_helpers import merge_config from esphome.config_helpers import merge_config, Extend, Remove
CODEOWNERS = ["@esphome/core"] CODEOWNERS = ["@esphome/core"]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -105,7 +105,7 @@ def _substitute_item(substitutions, item, path, ignore_missing):
sub = _expand_substitutions(substitutions, item, path, ignore_missing) sub = _expand_substitutions(substitutions, item, path, ignore_missing)
if sub != item: if sub != item:
return sub return sub
elif isinstance(item, core.Lambda): elif isinstance(item, (core.Lambda, Extend, Remove)):
sub = _expand_substitutions(substitutions, item.value, path, ignore_missing) sub = _expand_substitutions(substitutions, item.value, path, ignore_missing)
if sub != item: if sub != item:
item.value = sub item.value = sub

View file

@ -8,6 +8,9 @@ class Extend:
def __str__(self): def __str__(self):
return f"!extend {self.value}" return f"!extend {self.value}"
def __repr__(self):
return f"Extend({self.value})"
def __eq__(self, b): def __eq__(self, b):
""" """
Check if two Extend objects contain the same ID. Check if two Extend objects contain the same ID.
@ -24,6 +27,9 @@ class Remove:
def __str__(self): def __str__(self):
return f"!remove {self.value}" return f"!remove {self.value}"
def __repr__(self):
return f"Remove({self.value})"
def __eq__(self, b): def __eq__(self, b):
""" """
Check if two Remove objects contain the same ID. Check if two Remove objects contain the same ID.
@ -50,14 +56,19 @@ def merge_config(full_old, full_new):
return new return new
res = old.copy() res = old.copy()
ids = { ids = {
v[CONF_ID]: i v_id: i
for i, v in enumerate(res) for i, v in enumerate(res)
if CONF_ID in v and isinstance(v[CONF_ID], str) if (v_id := v.get(CONF_ID)) and isinstance(v_id, str)
} }
extend_ids = {
v_id.value: i
for i, v in enumerate(res)
if (v_id := v.get(CONF_ID)) and isinstance(v_id, Extend)
}
ids_to_delete = [] ids_to_delete = []
for v in new: for v in new:
if CONF_ID in v: if new_id := v.get(CONF_ID):
new_id = v[CONF_ID]
if isinstance(new_id, Extend): if isinstance(new_id, Extend):
new_id = new_id.value new_id = new_id.value
if new_id in ids: if new_id in ids:
@ -69,6 +80,14 @@ def merge_config(full_old, full_new):
if new_id in ids: if new_id in ids:
ids_to_delete.append(ids[new_id]) ids_to_delete.append(ids[new_id])
continue continue
elif (
new_id in extend_ids
): # When a package is extending a non-packaged item
extend_res = res[extend_ids[new_id]]
extend_res[CONF_ID] = new_id
new_v = merge(v, extend_res)
res[extend_ids[new_id]] = new_v
continue
else: else:
ids[new_id] = len(res) ids[new_id] = len(res)
res.append(v) res.append(v)