# Mirror of https://github.com/esphome/esphome.git
# (synced 2024-12-18 03:24:54 +01:00; 584 lines, 21 KiB, Python)
from __future__ import annotations
|
|
|
|
import fnmatch
|
|
import functools
|
|
import inspect
|
|
from io import TextIOWrapper
|
|
import logging
|
|
import math
|
|
import os
|
|
from typing import Any
|
|
import uuid
|
|
|
|
import yaml
|
|
from yaml import SafeLoader as PurePythonLoader
|
|
import yaml.constructor
|
|
|
|
try:
|
|
from yaml import CSafeLoader as FastestAvailableSafeLoader
|
|
except ImportError:
|
|
FastestAvailableSafeLoader = PurePythonLoader
|
|
|
|
from esphome import core
|
|
from esphome.config_helpers import Extend, Remove
|
|
from esphome.core import (
|
|
CORE,
|
|
DocumentRange,
|
|
EsphomeError,
|
|
IPAddress,
|
|
Lambda,
|
|
MACAddress,
|
|
TimePeriod,
|
|
)
|
|
from esphome.helpers import add_class_to_obj
|
|
from esphome.util import OrderedDict, filter_yaml_files
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# Mostly copied from Home Assistant because that code works fine and
# let's not reinvent the wheel here

# File name of the secrets file that `!secret` looks up.
SECRET_YAML = "secrets.yaml"
# Reset by load_yaml() and dump().
_SECRET_CACHE = dict()
# str(secret value) -> secret name; filled by the `!secret` constructor and
# consulted when dumping so secret values are written back as `!secret name`.
_SECRET_VALUES = dict()
|
|
|
|
|
|
class ESPHomeDataBase:
    """Mixin that attaches YAML source-position information to a value.

    The backing attributes are only set by from_node()/from_database(), so the
    properties fall back to ``None`` / ``0`` when nothing was recorded.
    """

    @property
    def esp_range(self):
        """DocumentRange of the originating YAML node, or None if unknown."""
        return getattr(self, "_esp_range", None)

    @property
    def content_offset(self):
        """Line offset of the content relative to the node start.

        1 when recorded from a block scalar (``|`` or ``>``), otherwise 0.
        """
        return getattr(self, "_content_offset", 0)

    def from_node(self, node):
        """Record the source range of *node* (a PyYAML node) on this value."""
        # pylint: disable=attribute-defined-outside-init
        self._esp_range = DocumentRange.from_marks(node.start_mark, node.end_mark)
        # Literal (|) and folded (>) block scalars start their content on the
        # line after the indicator. Compare against the two styles explicitly:
        # libyaml's CSafeLoader reports plain scalars with style "" and
        # `"" in "|>"` is True, which would shift every plain scalar by a line.
        if isinstance(node, yaml.ScalarNode) and node.style in ("|", ">"):
            self._content_offset = 1

    def from_database(self, database):
        """Copy the recorded position from another ESPHomeDataBase value."""
        # pylint: disable=attribute-defined-outside-init
        self._esp_range = database.esp_range
        self._content_offset = database.content_offset
|
|
|
|
|
|
class ESPForceValue:
    """Marker mixin added (via add_class_to_obj) to values tagged ``!force``."""
|
|
|
|
|
|
def make_data_base(value, from_database: ESPHomeDataBase = None):
    """Mix ESPHomeDataBase into *value* so it can carry source-position info.

    If *from_database* is given, its recorded position is copied onto the
    result. Values whose class cannot be extended (add_class_to_obj raises
    TypeError) are returned as they are at that point.
    """
    try:
        value = add_class_to_obj(value, ESPHomeDataBase)
        if from_database is not None:
            value.from_database(from_database)
    except TypeError:
        # Adding the class is best-effort; keep whatever we have.
        pass
    return value
|
|
|
|
|
|
def _add_data_ref(fn):
    """Decorate a YAML constructor so its result records the source node.

    The wrapped constructor's result is passed through make_data_base() and,
    when that succeeds, stamped with the node's position via from_node().
    """

    @functools.wraps(fn)
    def wrapped(loader, node):
        outcome = fn(loader, node)
        if inspect.isgenerator(outcome):
            # Newer PyYAML constructors are two-step generators: the first
            # yield hands out the (possibly still empty) object, and resuming
            # the generator fills it in. Take the object, then drain the rest.
            steps = outcome
            outcome = next(steps)
            for _unused in steps:
                pass
        outcome = make_data_base(outcome)
        if isinstance(outcome, ESPHomeDataBase):
            outcome.from_node(node)
        return outcome

    return wrapped
|
|
|
|
|
|
class ESPHomeLoaderMixin:
    """Loader class that keeps track of line numbers.

    Mixed in front of a PyYAML SafeLoader (see ESPHomeLoader and
    ESPHomePurePythonLoader). Every constructor is wrapped by _add_data_ref so
    the constructed values carry their source position. ``self.name`` is the
    path of the file being parsed (set by _load_yaml_internal_with_type);
    includes and secrets are resolved relative to its directory.
    """

    @_add_data_ref
    def construct_yaml_int(self, node):
        return super().construct_yaml_int(node)

    @_add_data_ref
    def construct_yaml_float(self, node):
        return super().construct_yaml_float(node)

    @_add_data_ref
    def construct_yaml_binary(self, node):
        return super().construct_yaml_binary(node)

    @_add_data_ref
    def construct_yaml_omap(self, node):
        return super().construct_yaml_omap(node)

    @_add_data_ref
    def construct_yaml_str(self, node):
        return super().construct_yaml_str(node)

    @_add_data_ref
    def construct_yaml_seq(self, node):
        return super().construct_yaml_seq(node)

    @_add_data_ref
    def construct_yaml_map(self, node):
        """Construct a mapping node into an OrderedDict.

        Every key is converted with str() and tagged with its source position.
        Duplicate keys are rejected, and ``<<`` merge keys are resolved here
        following https://yaml.org/type/merge.html.

        Raises:
            yaml.constructor.ConstructorError: for an unhashable or duplicate
                key, or a merge value that is not a mapping or a list of
                mappings.
        """
        assert isinstance(node, yaml.MappingNode)
        # Key-value pairs declared directly in this mapping
        pairs = []
        # Pairs found while resolving merge keys ('<<'); added to `pairs` in a
        # second pass
        merge_pairs = []
        # Keys seen so far -> start mark of their declaration (None for merged
        # keys). Used to report duplicates and to decide which merged keys win.
        seen_keys = {}

        for key_node, value_node in node.value:
            # merge key is '<<'
            is_merge_key = key_node.tag == "tag:yaml.org,2002:merge"
            # Keys resolved to the 'value' tag (PyYAML's resolver uses it for a
            # bare '=') are constructed as plain strings.
            is_default_tag = key_node.tag == "tag:yaml.org,2002:value"

            if is_default_tag:
                key_node.tag = "tag:yaml.org,2002:str"

            if not is_merge_key:
                # base case, this is a simple key-value pair
                key = self.construct_object(key_node)
                value = self.construct_object(value_node)

                # Check if key is hashable
                try:
                    hash(key)
                except TypeError:
                    # pylint: disable=raise-missing-from
                    raise yaml.constructor.ConstructorError(
                        f'Invalid key "{key}" (not hashable)', key_node.start_mark
                    )

                key = make_data_base(str(key))
                key.from_node(key_node)

                # Check if it is a duplicate key
                if key in seen_keys:
                    raise yaml.constructor.ConstructorError(
                        f'Duplicate key "{key}"',
                        key_node.start_mark,
                        "NOTE: Previous declaration here:",
                        seen_keys[key],
                    )
                seen_keys[key] = key_node.start_mark

                pairs.append((key, value))
                continue

            # This is a merge key, resolve value and add to merge_pairs
            value = self.construct_object(value_node)
            if isinstance(value, dict):
                # direct merge, like "<<: {some_key: some_value}"
                merge_pairs.extend(value.items())
            elif isinstance(value, list):
                # sequence merge, like "<<: [{some_key: some_value}, {other_key: some_value}]"
                for item in value:
                    if not isinstance(item, dict):
                        raise yaml.constructor.ConstructorError(
                            "While constructing a mapping",
                            node.start_mark,
                            f"Expected a mapping for merging, but found {type(item)}",
                            value_node.start_mark,
                        )
                    merge_pairs.extend(item.items())
            else:
                raise yaml.constructor.ConstructorError(
                    "While constructing a mapping",
                    node.start_mark,
                    f"Expected a mapping or list of mappings for merging, but found {type(value)}",
                    value_node.start_mark,
                )

        if merge_pairs:
            # Merge the collected pairs into the base pairs
            # (https://yaml.org/type/merge.html). Keys already in the current
            # mapping are not overridden ("... each of its key/value pairs is
            # inserted into the current mapping, unless the key already exists
            # in it."), and for a sequence merge, earlier entries win ("Keys in
            # mapping nodes earlier in the sequence override keys specified in
            # later mapping nodes.").
            for key, value in merge_pairs:
                if key in seen_keys:
                    continue
                pairs.append((key, value))
                # Remember merged keys so later sequence entries don't override.
                seen_keys[key] = None

        return OrderedDict(pairs)

    @_add_data_ref
    def construct_env_var(self, node):
        """Resolve ``!env_var NAME [default words...]``.

        Words after the variable name form the default value (joined with
        single spaces). Without a default, an unset variable -- or an empty
        tag value -- raises yaml.MarkedYAMLError.
        """
        args = node.value.split()
        # Check for a default value
        if len(args) > 1:
            return os.getenv(args[0], " ".join(args[1:]))
        # Guard against an empty value, which previously raised IndexError.
        if args and args[0] in os.environ:
            return os.environ[args[0]]
        raise yaml.MarkedYAMLError(
            f"Environment variable '{node.value}' not defined", node.start_mark
        )

    @property
    def _directory(self):
        """Directory of the file currently being parsed."""
        return os.path.dirname(self.name)

    def _rel_path(self, *args):
        """Join *args* onto the directory of the file currently being parsed."""
        return os.path.join(self._directory, *args)

    @_add_data_ref
    def construct_secret(self, node):
        """Resolve ``!secret name`` from secrets.yaml.

        The secrets file next to the current file is tried first. If that
        fails and the current file is not CORE.config_path, the secrets file
        next to the main config is tried. The resolved value is recorded in
        _SECRET_VALUES so dump() can write it back as ``!secret name``.

        Raises:
            EsphomeError: if no secrets file could be loaded.
            yaml.MarkedYAMLError: if the secret name is not defined.
        """
        try:
            secrets = _load_yaml_internal(self._rel_path(SECRET_YAML))
        except EsphomeError as e:
            if self.name == CORE.config_path:
                raise
            try:
                main_config_dir = os.path.dirname(CORE.config_path)
                main_secret_yml = os.path.join(main_config_dir, SECRET_YAML)
                secrets = _load_yaml_internal(main_secret_yml)
            except EsphomeError as er:
                raise EsphomeError(f"{e}\n{er}") from er

        if node.value not in secrets:
            raise yaml.MarkedYAMLError(
                f"Secret '{node.value}' not defined", node.start_mark
            )
        val = secrets[node.value]
        _SECRET_VALUES[str(val)] = node.value
        return val

    @_add_data_ref
    def construct_include(self, node):
        """Handle ``!include``.

        Accepts either a scalar file path or a mapping with ``file``
        (required) and optional ``vars``. The path is resolved relative to the
        including file. ``vars`` (stringified) plus any top-level ``defaults``
        of the included file are substituted into the included content.
        """

        def extract_file_vars(node):
            fields = self.construct_yaml_map(node)
            file = fields.get("file")
            if file is None:
                raise yaml.MarkedYAMLError("Must include 'file'", node.start_mark)
            file_vars = fields.get("vars")
            if file_vars:
                file_vars = {k: str(v) for k, v in file_vars.items()}
            return file, file_vars

        def substitute_vars(config, file_vars):
            from esphome.components import substitutions
            from esphome.const import CONF_DEFAULTS, CONF_SUBSTITUTIONS

            org_subs = None
            result = config
            if not isinstance(config, dict):
                # when the included yaml contains a list or a scalar
                # wrap it into an OrderedDict because do_substitution_pass expects it
                result = OrderedDict([("yaml", config)])
            elif CONF_SUBSTITUTIONS in result:
                # Set the file's own substitutions aside; restored below.
                org_subs = result.pop(CONF_SUBSTITUTIONS)

            defaults = {}
            if CONF_DEFAULTS in result:
                defaults = result.pop(CONF_DEFAULTS)

            # Caller-supplied vars take precedence over the file's defaults.
            result[CONF_SUBSTITUTIONS] = file_vars
            for k, v in defaults.items():
                if k not in result[CONF_SUBSTITUTIONS]:
                    result[CONF_SUBSTITUTIONS][k] = v

            # Ignore missing vars that refer to the top level substitutions
            substitutions.do_substitution_pass(result, None, ignore_missing=True)
            result.pop(CONF_SUBSTITUTIONS)

            if not isinstance(config, dict):
                result = result["yaml"]  # unwrap the result
            elif org_subs:
                result[CONF_SUBSTITUTIONS] = org_subs
            return result

        if isinstance(node, yaml.nodes.MappingNode):
            file, file_vars = extract_file_vars(node)
        else:
            file, file_vars = node.value, None

        result = _load_yaml_internal(self._rel_path(file))
        if not file_vars:
            file_vars = {}
        return substitute_vars(result, file_vars)

    @_add_data_ref
    def construct_include_dir_list(self, node):
        """Load every ``*.yaml`` file under the directory into a list."""
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        return [_load_yaml_internal(f) for f in files]

    @_add_data_ref
    def construct_include_dir_merge_list(self, node):
        """Concatenate the lists loaded from ``*.yaml`` files under the directory.

        Files whose content is not a list are skipped.
        """
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        merged_list = []
        for fname in files:
            loaded_yaml = _load_yaml_internal(fname)
            if isinstance(loaded_yaml, list):
                merged_list.extend(loaded_yaml)
        return merged_list

    @_add_data_ref
    def construct_include_dir_named(self, node):
        """Map each ``*.yaml`` file's base name (without extension) to its content."""
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        mapping = OrderedDict()
        for fname in files:
            filename = os.path.splitext(os.path.basename(fname))[0]
            mapping[filename] = _load_yaml_internal(fname)
        return mapping

    @_add_data_ref
    def construct_include_dir_merge_named(self, node):
        """Merge the mappings loaded from ``*.yaml`` files under the directory.

        Later files overwrite earlier keys; non-mapping files are skipped.
        """
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        mapping = OrderedDict()
        for fname in files:
            loaded_yaml = _load_yaml_internal(fname)
            if isinstance(loaded_yaml, dict):
                mapping.update(loaded_yaml)
        return mapping

    @_add_data_ref
    def construct_lambda(self, node):
        """Wrap the tag's scalar text in a Lambda."""
        return Lambda(str(node.value))

    @_add_data_ref
    def construct_force(self, node):
        """Construct the scalar and mark it with ESPForceValue."""
        obj = self.construct_scalar(node)
        return add_class_to_obj(obj, ESPForceValue)

    @_add_data_ref
    def construct_extend(self, node):
        """Wrap the tag's scalar text in an Extend marker."""
        return Extend(str(node.value))

    @_add_data_ref
    def construct_remove(self, node):
        """Wrap the tag's scalar text in a Remove marker."""
        return Remove(str(node.value))
|
|
|
|
|
|
class ESPHomeLoader(ESPHomeLoaderMixin, FastestAvailableSafeLoader):
    """Line-number-tracking loader built on the fastest available SafeLoader.

    FastestAvailableSafeLoader is libyaml's CSafeLoader when it can be imported,
    otherwise PyYAML's pure-Python SafeLoader.
    """
|
|
|
|
|
|
class ESPHomePurePythonLoader(ESPHomeLoaderMixin, PurePythonLoader):
    """Line-number-tracking loader built on PyYAML's pure-Python SafeLoader.

    parse_yaml() falls back to it when ESPHomeLoader fails, because its error
    messages are more readable.
    """
|
|
|
|
|
|
for _loader in (ESPHomeLoader, ESPHomePurePythonLoader):
    for _tag, _constructor in (
        # Standard YAML tags, routed through the position-tracking overrides
        ("tag:yaml.org,2002:int", _loader.construct_yaml_int),
        ("tag:yaml.org,2002:float", _loader.construct_yaml_float),
        ("tag:yaml.org,2002:binary", _loader.construct_yaml_binary),
        ("tag:yaml.org,2002:omap", _loader.construct_yaml_omap),
        ("tag:yaml.org,2002:str", _loader.construct_yaml_str),
        ("tag:yaml.org,2002:seq", _loader.construct_yaml_seq),
        ("tag:yaml.org,2002:map", _loader.construct_yaml_map),
        # ESPHome-specific tags
        ("!env_var", _loader.construct_env_var),
        ("!secret", _loader.construct_secret),
        ("!include", _loader.construct_include),
        ("!include_dir_list", _loader.construct_include_dir_list),
        ("!include_dir_merge_list", _loader.construct_include_dir_merge_list),
        ("!include_dir_named", _loader.construct_include_dir_named),
        ("!include_dir_merge_named", _loader.construct_include_dir_merge_named),
        ("!lambda", _loader.construct_lambda),
        ("!force", _loader.construct_force),
        ("!extend", _loader.construct_extend),
        ("!remove", _loader.construct_remove),
    ):
        _loader.add_constructor(_tag, _constructor)
# Keep the module namespace as before: only `_loader` is left behind.
del _tag, _constructor
|
|
|
|
|
|
def load_yaml(fname: str, clear_secrets: bool = True) -> Any:
    """Load the ESPHome YAML file *fname*.

    When *clear_secrets* is true (the default), the module-level secret
    bookkeeping (_SECRET_VALUES and _SECRET_CACHE) is reset before loading.
    """
    if clear_secrets:
        for registry in (_SECRET_VALUES, _SECRET_CACHE):
            registry.clear()
    return _load_yaml_internal(fname)
|
|
|
|
|
|
def parse_yaml(file_name: str, file_handle: TextIOWrapper) -> Any:
    """Parse YAML from *file_handle*, recording *file_name* as its source.

    The fast ESPHomeLoader is tried first. If it fails with EsphomeError, the
    stream is rewound and parsed again with ESPHomePurePythonLoader, whose
    error messages are more readable.
    """
    try:
        return _load_yaml_internal_with_type(ESPHomeLoader, file_name, file_handle)
    except EsphomeError:
        file_handle.seek(0)
        return _load_yaml_internal_with_type(
            loader_type=ESPHomePurePythonLoader,
            fname=file_name,
            content=file_handle,
        )
|
|
|
|
|
|
def _load_yaml_internal(fname: str) -> Any:
    """Open *fname* as UTF-8 text and parse it with parse_yaml().

    Raises:
        EsphomeError: if the file cannot be opened/read or is not valid UTF-8
            (parse errors from parse_yaml() propagate unchanged).
    """
    try:
        with open(fname, encoding="utf-8") as stream:
            parsed = parse_yaml(fname, stream)
        return parsed
    except (UnicodeDecodeError, OSError) as read_error:
        raise EsphomeError(f"Error reading file {fname}: {read_error}") from read_error
|
|
|
|
|
|
def _load_yaml_internal_with_type(
    loader_type: type[ESPHomeLoader] | type[ESPHomePurePythonLoader],
    fname: str,
    content: TextIOWrapper,
) -> Any:
    """Parse a single YAML document from *content* with a fresh *loader_type*.

    The loader's ``name`` is set to *fname* so relative includes and secrets
    resolve against that file. A falsy document (e.g. an empty file) yields
    an empty OrderedDict. The loader is always disposed.

    Raises:
        EsphomeError: wrapping any yaml.YAMLError raised while parsing.
    """
    loader = loader_type(content)
    loader.name = fname
    try:
        document = loader.get_single_data()
        return document if document else OrderedDict()
    except yaml.YAMLError as exc:
        raise EsphomeError(exc) from exc
    finally:
        loader.dispose()
|
|
|
|
|
|
def dump(dict_, show_secrets=False):
    """Serialize *dict_* to a YAML string using ESPHomeDumper.

    With *show_secrets* the secret bookkeeping is cleared first, so values that
    came from ``!secret`` are written out literally instead of as ``!secret``
    tags.
    """
    if show_secrets:
        for registry in (_SECRET_VALUES, _SECRET_CACHE):
            registry.clear()
    dump_options = {
        "default_flow_style": False,
        "allow_unicode": True,
        "Dumper": ESPHomeDumper,
    }
    return yaml.dump(dict_, **dump_options)
|
|
|
|
|
|
def _is_file_valid(name):
    """Return False for hidden entries (names starting with '.'), else True."""
    return name[:1] != "."
|
|
|
|
|
|
def _find_files(directory, pattern):
    """Recursively yield paths below *directory* whose basename matches *pattern*.

    Hidden files and directories (names starting with '.') are skipped.
    Entries are visited in sorted order, so the result -- and therefore the
    order of !include_dir_* lists and merges -- does not depend on the
    filesystem's directory-listing order.
    """
    for root, dirs, files in os.walk(directory, topdown=True):
        # With topdown=True, rewriting `dirs` in place both prunes hidden
        # directories and fixes the order in which os.walk descends.
        dirs[:] = sorted(d for d in dirs if _is_file_valid(d))
        for basename in sorted(files):
            if _is_file_valid(basename) and fnmatch.fnmatch(basename, pattern):
                yield os.path.join(root, basename)
|
|
|
|
|
|
def is_secret(value):
    """Return the secret name recorded for str(*value*), or None if there is none."""
    try:
        lookup_key = str(value)
    except (KeyError, ValueError):
        return None
    return _SECRET_VALUES.get(lookup_key)
|
|
|
|
|
|
class ESPHomeDumper(yaml.SafeDumper):
    """SafeDumper that keeps mapping order and writes secret values as ``!secret``."""

    def represent_mapping(self, tag, mapping, flow_style=None):
        """Represent *mapping* without sorting its keys.

        Flow style is chosen only when neither the caller nor
        default_flow_style decides it: flow style is used iff every key and
        value is a plain (unstyled) scalar.
        """
        pairs = []
        node = yaml.MappingNode(tag, pairs, flow_style=flow_style)
        if self.alias_key is not None:
            # Record the node before representing children, as PyYAML's own
            # representers do for anchors/aliases.
            self.represented_objects[self.alias_key] = node
        all_plain_scalars = True
        items = list(mapping.items()) if hasattr(mapping, "items") else mapping
        for key_obj, value_obj in items:
            key_node = self.represent_data(key_obj)
            value_node = self.represent_data(value_obj)
            for child in (key_node, value_node):
                if not isinstance(child, yaml.ScalarNode) or child.style:
                    all_plain_scalars = False
            pairs.append((key_node, value_node))
        if flow_style is None:
            node.flow_style = (
                self.default_flow_style
                if self.default_flow_style is not None
                else all_plain_scalars
            )
        return node

    def represent_secret(self, value):
        """Emit ``!secret <name>`` for a value recorded in _SECRET_VALUES."""
        secret_name = _SECRET_VALUES[str(value)]
        return self.represent_scalar(tag="!secret", value=secret_name)

    def represent_stringify(self, value):
        """Emit str(value) as a YAML string, or as a secret reference."""
        if not is_secret(value):
            return self.represent_scalar(tag="tag:yaml.org,2002:str", value=str(value))
        return self.represent_secret(value)

    # pylint: disable=arguments-renamed
    def represent_bool(self, value):
        """Emit lowercase ``true``/``false``."""
        text = "true" if value else "false"
        return self.represent_scalar("tag:yaml.org,2002:bool", text)

    # pylint: disable=arguments-renamed
    def represent_int(self, value):
        """Emit an integer, or a secret reference."""
        if not is_secret(value):
            return self.represent_scalar(tag="tag:yaml.org,2002:int", value=str(value))
        return self.represent_secret(value)

    # pylint: disable=arguments-renamed
    def represent_float(self, value):
        """Emit a float in YAML !!float syntax, or a secret reference."""
        if is_secret(value):
            return self.represent_secret(value)
        if math.isnan(value):
            text = ".nan"
        elif math.isinf(value):
            text = ".inf" if value > 0 else "-.inf"
        else:
            text = repr(value).lower()
            # repr() may omit the fractional part (repr(1e17) == '1e+17'),
            # which is not a valid !!float; insert '.0' before the exponent.
            if "." not in text and "e" in text:
                text = text.replace("e", ".0e", 1)
        return self.represent_scalar(tag="tag:yaml.org,2002:float", value=text)

    def represent_lambda(self, value):
        """Emit a Lambda as a ``!lambda`` literal block, or a secret reference."""
        body = value.value
        if is_secret(body):
            return self.represent_secret(body)
        return self.represent_scalar(tag="!lambda", value=body, style="|")

    def represent_id(self, value):
        """Emit an ID via its ``id`` attribute, or a secret reference."""
        identifier = value.id
        if is_secret(identifier):
            return self.represent_secret(identifier)
        return self.represent_stringify(identifier)
|
|
|
|
|
|
# Containers: mappings go through the order-preserving represent_mapping.
ESPHomeDumper.add_multi_representer(
    dict, lambda dumper, value: dumper.represent_mapping("tag:yaml.org,2002:map", value)
)
ESPHomeDumper.add_multi_representer(
    list,
    lambda dumper, value: dumper.represent_sequence("tag:yaml.org,2002:seq", value),
)
# Scalars and ESPHome value types (secret-aware where the method checks is_secret).
for _type, _representer in (
    (bool, ESPHomeDumper.represent_bool),
    (str, ESPHomeDumper.represent_stringify),
    (int, ESPHomeDumper.represent_int),
    (float, ESPHomeDumper.represent_float),
    (IPAddress, ESPHomeDumper.represent_stringify),
    (MACAddress, ESPHomeDumper.represent_stringify),
    (TimePeriod, ESPHomeDumper.represent_stringify),
    (Lambda, ESPHomeDumper.represent_lambda),
    (core.ID, ESPHomeDumper.represent_id),
    (uuid.UUID, ESPHomeDumper.represent_stringify),
):
    ESPHomeDumper.add_multi_representer(_type, _representer)
# Do not leave the loop variables behind in the module namespace.
del _type, _representer
|