From d09dff3ae31285e7afff255e679b0cf637dda5d6 Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Sat, 7 Dec 2019 13:43:51 +0100 Subject: [PATCH] Clean up YAML Mapping construction (#910) * Clean up YAML Mapping construction Fixes https://github.com/esphome/issues/issues/902 * Clean up DataBase * Update error messages --- esphome/config.py | 3 +- esphome/helpers.py | 10 +-- esphome/yaml_util.py | 180 ++++++++++++++++++++----------------------- 3 files changed, 89 insertions(+), 104 deletions(-) diff --git a/esphome/config.py b/esphome/config.py index 53449c3e85..027c28bc5d 100644 --- a/esphome/config.py +++ b/esphome/config.py @@ -663,8 +663,7 @@ class InvalidYAMLError(EsphomeError): except UnicodeDecodeError: base = repr(base_exc) base = decode_text(base) - message = u"Invalid YAML syntax. Please see YAML syntax reference or use an " \ - u"online YAML syntax validator:\n\n{}".format(base) + message = u"Invalid YAML syntax:\n\n{}".format(base) super(InvalidYAMLError, self).__init__(message) self.base_exc = base_exc diff --git a/esphome/helpers.py b/esphome/helpers.py index e91b13a735..179452c353 100644 --- a/esphome/helpers.py +++ b/esphome/helpers.py @@ -266,11 +266,11 @@ def file_compare(path1, path2): # A dict of types that need to be converted to heaptypes before a class can be added # to the object _TYPE_OVERLOADS = { - int: type('int', (int,), dict()), - float: type('float', (float,), dict()), - str: type('str', (str,), dict()), - dict: type('dict', (str,), dict()), - list: type('list', (list,), dict()), + int: type('EInt', (int,), dict()), + float: type('EFloat', (float,), dict()), + str: type('EStr', (str,), dict()), + dict: type('EDict', (str,), dict()), + list: type('EList', (list,), dict()), } if IS_PY2: diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py index 69f3c70ede..0e5b4593e9 100644 --- a/esphome/yaml_util.py +++ b/esphome/yaml_util.py @@ -43,7 +43,11 @@ class ESPForceValue(object): def make_data_base(value): - return add_class_to_obj(value, ESPHomeDataBase) + try: + return add_class_to_obj(value, ESPHomeDataBase) + except TypeError: + # Adding class failed, ignore error + return value def _add_data_ref(fn): @@ -92,50 +96,82 @@ class ESPHomeLoader(yaml.SafeLoader): # pylint: disable=too-many-ancestors def construct_yaml_seq(self, node): return super(ESPHomeLoader, self).construct_yaml_seq(node) - def custom_flatten_mapping(self, node): - merge = [] - index = 0 - while index < len(node.value): - key_node, value_node = node.value[index] - if key_node.tag == 'tag:yaml.org,2002:merge': - del node.value[index] - if isinstance(value_node, yaml.MappingNode): - self.custom_flatten_mapping(value_node) - merge.extend(value_node.value) - elif isinstance(value_node, yaml.SequenceNode): - submerge = [] - for subnode in value_node.value: - if not isinstance(subnode, yaml.MappingNode): - raise yaml.constructor.ConstructorError( - "while constructing a mapping", node.start_mark, - "expected a mapping for merging, but found {}".format(subnode.id), - subnode.start_mark) - self.custom_flatten_mapping(subnode) - submerge.append(subnode.value) - submerge.reverse() - for value in submerge: - merge.extend(value) - else: - raise yaml.constructor.ConstructorError( - "while constructing a mapping", node.start_mark, - "expected a mapping or list of mappings for merging, " - "but found {}".format(value_node.id), value_node.start_mark) - elif key_node.tag == 'tag:yaml.org,2002:value': - key_node.tag = 'tag:yaml.org,2002:str' - index += 1 - else: - index += 1 - if merge: - # https://yaml.org/type/merge.html - # Generate a set of keys that should override values in `merge`. - haystack = {key.value for (key, _) in node.value} + @_add_data_ref + def construct_yaml_map(self, node): + """Traverses the given mapping node and returns a list of constructed key-value pairs.""" + assert isinstance(node, yaml.MappingNode) + # A list of key-value pairs we find in the current mapping + pairs = [] + # A list of key-value pairs we find while resolving merges ('<<' key), will be + # added to pairs in a second pass + merge_pairs = [] + # A dict of seen keys so far, used to alert the user of duplicate keys and checking + # which keys to merge. + # Value of dict items is the start mark of the previous declaration. + seen_keys = {} + for key_node, value_node in node.value: + # merge key is '<<' + is_merge_key = key_node.tag == 'tag:yaml.org,2002:merge' + # key has no explicit tag set + is_default_tag = key_node.tag == 'tag:yaml.org,2002:value' + + if is_default_tag: + # Default tag for mapping keys is string + key_node.tag = 'tag:yaml.org,2002:str' + + if not is_merge_key: + # base case, this is a simple key-value pair + key = self.construct_object(key_node) + value = self.construct_object(value_node) + + # Check if key is hashable + try: + hash(key) + except TypeError: + raise yaml.constructor.ConstructorError( + 'Invalid key "{}" (not hashable)'.format(key), key_node.start_mark) + + # Check if it is a duplicate key + if key in seen_keys: + raise yaml.constructor.ConstructorError( + 'Duplicate key "{}"'.format(key), key_node.start_mark, + 'NOTE: Previous declaration here:', seen_keys[key], + ) + seen_keys[key] = key_node.start_mark + + # Add to pairs + pairs.append((key, value)) + continue + + # This is a merge key, resolve value and add to merge_pairs + value = self.construct_object(value_node) + if isinstance(value, dict): + # base case, copy directly to merge_pairs + # direct merge, like "<<: {some_key: some_value}" + merge_pairs.extend(value.items()) + elif isinstance(value, list): + # sequence merge, like "<<: [{some_key: some_value}, {other_key: some_value}]" + for item in value: + if not isinstance(item, dict): + raise yaml.constructor.ConstructorError( + "While constructing a mapping", node.start_mark, + "Expected a mapping for merging, but found {}".format(type(item)), + value_node.start_mark) + merge_pairs.extend(item.items()) + else: + raise yaml.constructor.ConstructorError( + "While constructing a mapping", node.start_mark, + "Expected a mapping or list of mappings for merging, " + "but found {}".format(type(value)), value_node.start_mark) + + if merge_pairs: + # We found some merge keys along the way, merge them into base pairs + # https://yaml.org/type/merge.html # Construct a new merge set with values overridden by current mapping or earlier # sequence entries removed - new_merge = [] - - for key, value in merge: - if key.value in haystack: + for key, value in merge_pairs: + if key in seen_keys: # key already in the current map or from an earlier merge sequence entry, # do not override # @@ -147,59 +183,11 @@ class ESPHomeLoader(yaml.SafeLoader): # pylint: disable=too-many-ancestors # turn according to its order in the sequence. Keys in mapping nodes earlier # in the sequence override keys specified in later mapping nodes." continue - new_merge.append((key, value)) - # Add key node to haystack, for sequence merge values. - haystack.add(key.value) - - # Merge - node.value = new_merge + node.value - - def custom_construct_pairs(self, node): - pairs = [] - for kv in node.value: - if isinstance(kv, yaml.ScalarNode): - obj = self.construct_object(kv) - if not isinstance(obj, dict): - raise EsphomeError( - "Expected mapping for anchored include tag, got {}".format(type(obj))) - for key, value in obj.items(): - pairs.append((key, value)) - else: - key_node, value_node = kv - key = self.construct_object(key_node) - value = self.construct_object(value_node) pairs.append((key, value)) + # Add key node to seen keys, for sequence merge values. + seen_keys[key] = None - return pairs - - @_add_data_ref - def construct_yaml_map(self, node): - self.custom_flatten_mapping(node) - nodes = self.custom_construct_pairs(node) - - seen = {} - for (key, _), nv in zip(nodes, node.value): - if isinstance(nv, yaml.ScalarNode): - line = nv.start_mark.line - else: - line = nv[0].start_mark.line - - try: - hash(key) - except TypeError: - raise yaml.MarkedYAMLError( - context="invalid key: \"{}\"".format(key), - context_mark=yaml.Mark(self.name, 0, line, -1, None, None) - ) - - if key in seen: - raise yaml.MarkedYAMLError( - context="duplicate key: \"{}\"".format(key), - context_mark=yaml.Mark(self.name, 0, line, -1, None, None) - ) - seen[key] = line - - return OrderedDict(nodes) + return OrderedDict(pairs) @_add_data_ref def construct_env_var(self, node): @@ -210,8 +198,7 @@ class ESPHomeLoader(yaml.SafeLoader): # pylint: disable=too-many-ancestors if args[0] in os.environ: return os.environ[args[0]] raise yaml.MarkedYAMLError( - context=u"Environment variable '{}' not defined".format(node.value), - context_mark=node.start_mark + u"Environment variable '{}' not defined".format(node.value), node.start_mark ) @property @@ -226,8 +213,7 @@ class ESPHomeLoader(yaml.SafeLoader): # pylint: disable=too-many-ancestors secrets = _load_yaml_internal(self._rel_path(SECRET_YAML)) if node.value not in secrets: raise yaml.MarkedYAMLError( - context=u"Secret '{}' not defined".format(node.value), - context_mark=node.start_mark + u"Secret '{}' not defined".format(node.value), node.start_mark ) val = secrets[node.value] _SECRET_VALUES[text_type(val)] = node.value