Add variable substitutions for !include (#3510)

This commit is contained in:
jimtng 2022-05-31 14:45:18 +10:00 committed by GitHub
parent 708672ec7e
commit 5aa42e5e66
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 95 additions and 17 deletions

View file

@ -48,7 +48,7 @@ VARIABLE_PROG = re.compile(
) )
def _expand_substitutions(substitutions, value, path): def _expand_substitutions(substitutions, value, path, ignore_missing):
if "$" not in value: if "$" not in value:
return value return value
@ -66,6 +66,7 @@ def _expand_substitutions(substitutions, value, path):
if name.startswith("{") and name.endswith("}"): if name.startswith("{") and name.endswith("}"):
name = name[1:-1] name = name[1:-1]
if name not in substitutions: if name not in substitutions:
if not ignore_missing:
_LOGGER.warning( _LOGGER.warning(
"Found '%s' (see %s) which looks like a substitution, but '%s' was " "Found '%s' (see %s) which looks like a substitution, but '%s' was "
"not declared", "not declared",
@ -92,37 +93,37 @@ def _expand_substitutions(substitutions, value, path):
return value return value
def _substitute_item(substitutions, item, path): def _substitute_item(substitutions, item, path, ignore_missing):
if isinstance(item, list): if isinstance(item, list):
for i, it in enumerate(item): for i, it in enumerate(item):
sub = _substitute_item(substitutions, it, path + [i]) sub = _substitute_item(substitutions, it, path + [i], ignore_missing)
if sub is not None: if sub is not None:
item[i] = sub item[i] = sub
elif isinstance(item, dict): elif isinstance(item, dict):
replace_keys = [] replace_keys = []
for k, v in item.items(): for k, v in item.items():
if path or k != CONF_SUBSTITUTIONS: if path or k != CONF_SUBSTITUTIONS:
sub = _substitute_item(substitutions, k, path + [k]) sub = _substitute_item(substitutions, k, path + [k], ignore_missing)
if sub is not None: if sub is not None:
replace_keys.append((k, sub)) replace_keys.append((k, sub))
sub = _substitute_item(substitutions, v, path + [k]) sub = _substitute_item(substitutions, v, path + [k], ignore_missing)
if sub is not None: if sub is not None:
item[k] = sub item[k] = sub
for old, new in replace_keys: for old, new in replace_keys:
item[new] = merge_config(item.get(old), item.get(new)) item[new] = merge_config(item.get(old), item.get(new))
del item[old] del item[old]
elif isinstance(item, str): elif isinstance(item, str):
sub = _expand_substitutions(substitutions, item, path) sub = _expand_substitutions(substitutions, item, path, ignore_missing)
if sub != item: if sub != item:
return sub return sub
elif isinstance(item, core.Lambda): elif isinstance(item, core.Lambda):
sub = _expand_substitutions(substitutions, item.value, path) sub = _expand_substitutions(substitutions, item.value, path, ignore_missing)
if sub != item: if sub != item:
item.value = sub item.value = sub
return None return None
def do_substitution_pass(config, command_line_substitutions): def do_substitution_pass(config, command_line_substitutions, ignore_missing=False):
if CONF_SUBSTITUTIONS not in config and not command_line_substitutions: if CONF_SUBSTITUTIONS not in config and not command_line_substitutions:
return return
@ -151,4 +152,4 @@ def do_substitution_pass(config, command_line_substitutions):
config[CONF_SUBSTITUTIONS] = substitutions config[CONF_SUBSTITUTIONS] = substitutions
# Move substitutions to the first place to replace substitutions in them correctly # Move substitutions to the first place to replace substitutions in them correctly
config.move_to_end(CONF_SUBSTITUTIONS, False) config.move_to_end(CONF_SUBSTITUTIONS, False)
_substitute_item(substitutions, config, []) _substitute_item(substitutions, config, [], ignore_missing)

View file

@ -251,7 +251,49 @@ class ESPHomeLoader(yaml.SafeLoader): # pylint: disable=too-many-ancestors
@_add_data_ref @_add_data_ref
def construct_include(self, node): def construct_include(self, node):
return _load_yaml_internal(self._rel_path(node.value)) def extract_file_vars(node):
fields = self.construct_yaml_map(node)
file = fields.get("file")
if file is None:
raise yaml.MarkedYAMLError("Must include 'file'", node.start_mark)
vars = fields.get("vars")
if vars:
vars = {k: str(v) for k, v in vars.items()}
return file, vars
def substitute_vars(config, vars):
from esphome.const import CONF_SUBSTITUTIONS
from esphome.components import substitutions
org_subs = None
result = config
if not isinstance(config, dict):
# when the included yaml contains a list or a scalar
# wrap it into an OrderedDict because do_substitution_pass expects it
result = OrderedDict([("yaml", config)])
elif CONF_SUBSTITUTIONS in result:
org_subs = result.pop(CONF_SUBSTITUTIONS)
result[CONF_SUBSTITUTIONS] = vars
# Ignore missing vars that refer to the top level substitutions
substitutions.do_substitution_pass(result, None, ignore_missing=True)
result.pop(CONF_SUBSTITUTIONS)
if not isinstance(config, dict):
result = result["yaml"] # unwrap the result
elif org_subs:
result[CONF_SUBSTITUTIONS] = org_subs
return result
if isinstance(node, yaml.nodes.MappingNode):
file, vars = extract_file_vars(node)
else:
file, vars = node.value, None
result = _load_yaml_internal(self._rel_path(file))
if vars:
result = substitute_vars(result, vars)
return result
@_add_data_ref @_add_data_ref
def construct_include_dir_list(self, node): def construct_include_dir_list(self, node):

View file

@ -0,0 +1,2 @@
---
ssid: ${name}

View file

@ -0,0 +1,2 @@
---
- ${var1}

View file

@ -0,0 +1 @@
${var1}

View file

@ -0,0 +1,17 @@
---
substitutions:
name: original
wifi: !include
file: includes/included.yaml
vars:
name: my_custom_ssid
esphome:
# should be substituted as 'original', not overwritten by vars in the !include above
name: ${name}
name_add_mac_suffix: true
platform: esp8266
board: !include { file: includes/scalar.yaml, vars: { var1: nodemcu } }
libraries: !include { file: includes/list.yaml, vars: { var1: Wire } }

View file

@ -0,0 +1,13 @@
from esphome import yaml_util
from esphome.components import substitutions
def test_include_with_vars(fixture_path):
yaml_file = fixture_path / "yaml_util" / "includetest.yaml"
actual = yaml_util.load_yaml(yaml_file)
substitutions.do_substitution_pass(actual, None)
assert actual["esphome"]["name"] == "original"
assert actual["esphome"]["libraries"][0] == "Wire"
assert actual["esphome"]["board"] == "nodemcu"
assert actual["wifi"]["ssid"] == "my_custom_ssid"