Improve config final validation (#1917)

This commit is contained in:
Otto Winter 2021-06-17 21:54:14 +02:00 committed by GitHub
parent c19b3ecd43
commit 2419bc3678
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 303 additions and 162 deletions

View file

@ -45,6 +45,9 @@ CONFIG_SCHEMA = (
.extend(cv.polling_component_schema("60s")) .extend(cv.polling_component_schema("60s"))
.extend(uart.UART_DEVICE_SCHEMA) .extend(uart.UART_DEVICE_SCHEMA)
) )
FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema(
"cse7766", baud_rate=4800, require_rx=True
)
async def to_code(config): async def to_code(config):
@ -64,9 +67,3 @@ async def to_code(config):
conf = config[CONF_POWER] conf = config[CONF_POWER]
sens = await sensor.new_sensor(conf) sens = await sensor.new_sensor(conf)
cg.add(var.set_power_sensor(sens)) cg.add(var.set_power_sensor(sens))
def validate(config, item_config):
uart.validate_device(
"cse7766", config, item_config, baud_rate=4800, require_tx=False
)

View file

@ -68,6 +68,9 @@ CONFIG_SCHEMA = cv.All(
} }
).extend(uart.UART_DEVICE_SCHEMA) ).extend(uart.UART_DEVICE_SCHEMA)
) )
FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema(
"dfplayer", baud_rate=9600, require_tx=True
)
async def to_code(config): async def to_code(config):
@ -80,12 +83,6 @@ async def to_code(config):
await automation.build_automation(trigger, [], conf) await automation.build_automation(trigger, [], conf)
def validate(config, item_config):
uart.validate_device(
"dfplayer", config, item_config, baud_rate=9600, require_rx=False
)
@automation.register_action( @automation.register_action(
"dfplayer.play_next", "dfplayer.play_next",
NextAction, NextAction,

View file

@ -62,6 +62,7 @@ CONFIG_SCHEMA = (
.extend(cv.polling_component_schema("20s")) .extend(cv.polling_component_schema("20s"))
.extend(uart.UART_DEVICE_SCHEMA) .extend(uart.UART_DEVICE_SCHEMA)
) )
FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema("gps", require_rx=True)
async def to_code(config): async def to_code(config):
@ -95,7 +96,3 @@ async def to_code(config):
# https://platformio.org/lib/show/1655/TinyGPSPlus # https://platformio.org/lib/show/1655/TinyGPSPlus
cg.add_library("1655", "1.0.2") # TinyGPSPlus, has name conflict cg.add_library("1655", "1.0.2") # TinyGPSPlus, has name conflict
def validate(config, item_config):
uart.validate_device("gps", config, item_config, require_tx=False)

View file

@ -2,6 +2,7 @@ import urllib.parse as urlparse
import esphome.codegen as cg import esphome.codegen as cg
import esphome.config_validation as cv import esphome.config_validation as cv
import esphome.final_validate as fv
from esphome import automation from esphome import automation
from esphome.const import ( from esphome.const import (
CONF_ID, CONF_ID,
@ -14,7 +15,6 @@ from esphome.const import (
CONF_URL, CONF_URL,
) )
from esphome.core import CORE, Lambda from esphome.core import CORE, Lambda
from esphome.core.config import PLATFORMIO_ESP8266_LUT
DEPENDENCIES = ["network"] DEPENDENCIES = ["network"]
AUTO_LOAD = ["json"] AUTO_LOAD = ["json"]
@ -36,29 +36,6 @@ CONF_VERIFY_SSL = "verify_ssl"
CONF_ON_RESPONSE = "on_response" CONF_ON_RESPONSE = "on_response"
def validate_framework(config):
if CORE.is_esp32:
return config
version = "RECOMMENDED"
if CONF_ARDUINO_VERSION in CORE.raw_config[CONF_ESPHOME]:
version = CORE.raw_config[CONF_ESPHOME][CONF_ARDUINO_VERSION]
if version in ["LATEST", "DEV"]:
return config
framework = (
PLATFORMIO_ESP8266_LUT[version]
if version in PLATFORMIO_ESP8266_LUT
else version
)
if framework < ARDUINO_VERSION_ESP8266["2.5.1"]:
raise cv.Invalid(
"This component is not supported on arduino framework version below 2.5.1"
)
return config
def validate_url(value): def validate_url(value):
value = cv.string(value) value = cv.string(value)
try: try:
@ -92,19 +69,36 @@ def validate_secure_url(config):
return config return config
CONFIG_SCHEMA = ( CONFIG_SCHEMA = cv.Schema(
cv.Schema(
{ {
cv.GenerateID(): cv.declare_id(HttpRequestComponent), cv.GenerateID(): cv.declare_id(HttpRequestComponent),
cv.Optional(CONF_USERAGENT, "ESPHome"): cv.string, cv.Optional(CONF_USERAGENT, "ESPHome"): cv.string,
cv.Optional( cv.Optional(CONF_TIMEOUT, default="5s"): cv.positive_time_period_milliseconds,
CONF_TIMEOUT, default="5s"
): cv.positive_time_period_milliseconds,
} }
).extend(cv.COMPONENT_SCHEMA)
def validate_framework(config):
if CORE.is_esp32:
return
# only for ESP8266
path = [CONF_ESPHOME, CONF_ARDUINO_VERSION]
version: str = fv.full_config.get().get_config_for_path(path)
reverse_map = {v: k for k, v in ARDUINO_VERSION_ESP8266.items()}
framework_version = reverse_map.get(version)
if framework_version is None or framework_version == "dev":
return
if framework_version < "2.5.1":
raise cv.Invalid(
"This component is not supported on arduino framework version below 2.5.1",
path=[cv.ROOT_CONFIG_PATH] + path,
) )
.add_extra(validate_framework)
.extend(cv.COMPONENT_SCHEMA)
) FINAL_VALIDATE_SCHEMA = cv.Schema(validate_framework)
async def to_code(config): async def to_code(config):

View file

@ -19,13 +19,12 @@ CONFIG_SCHEMA = cv.All(
).extend(spi.spi_device_schema(cs_pin_required=True)) ).extend(spi.spi_device_schema(cs_pin_required=True))
) )
FINAL_VALIDATE_SCHEMA = spi.final_validate_device_schema(
"rc522_spi", require_miso=True, require_mosi=True
)
async def to_code(config): async def to_code(config):
var = cg.new_Pvariable(config[CONF_ID]) var = cg.new_Pvariable(config[CONF_ID])
await rc522.setup_rc522(var, config) await rc522.setup_rc522(var, config)
await spi.register_spi_device(var, config) await spi.register_spi_device(var, config)
def validate(config, item_config):
# validate given SPI hub is suitable for rc522_spi, it needs both miso and mosi
spi.validate_device("rc522_spi", config, item_config, True, True)

View file

@ -1,6 +1,7 @@
import logging import logging
import esphome.codegen as cg import esphome.codegen as cg
import esphome.config_validation as cv import esphome.config_validation as cv
import esphome.final_validate as fv
from esphome import automation from esphome import automation
from esphome.components.output import FloatOutput from esphome.components.output import FloatOutput
from esphome.const import CONF_ID, CONF_OUTPUT, CONF_PLATFORM, CONF_TRIGGER_ID from esphome.const import CONF_ID, CONF_OUTPUT, CONF_PLATFORM, CONF_TRIGGER_ID
@ -36,12 +37,8 @@ CONFIG_SCHEMA = cv.Schema(
).extend(cv.COMPONENT_SCHEMA) ).extend(cv.COMPONENT_SCHEMA)
def validate(config, item_config): def validate_parent_output_config(value):
# Not adding this to FloatOutput as this is the only component which needs `update_frequency` platform = value.get(CONF_PLATFORM)
parent_config = config.get_config_by_id(item_config[CONF_OUTPUT])
platform = parent_config[CONF_PLATFORM]
PWM_GOOD = ["esp8266_pwm", "ledc"] PWM_GOOD = ["esp8266_pwm", "ledc"]
PWM_BAD = [ PWM_BAD = [
"ac_dimmer ", "ac_dimmer ",
@ -55,14 +52,25 @@ def validate(config, item_config):
] ]
if platform in PWM_BAD: if platform in PWM_BAD:
raise ValueError(f"Component rtttl cannot use {platform} as output component") raise cv.Invalid(f"Component rtttl cannot use {platform} as output component")
if platform not in PWM_GOOD: if platform not in PWM_GOOD:
_LOGGER.warning( _LOGGER.warning(
"Component rtttl is not known to work with the selected output type. Make sure this output supports custom frequency output method." "Component rtttl is not known to work with the selected output type. "
"Make sure this output supports custom frequency output method."
) )
FINAL_VALIDATE_SCHEMA = cv.Schema(
{
cv.Required(CONF_OUTPUT): fv.id_declaration_match_schema(
validate_parent_output_config
)
},
extra=cv.ALLOW_EXTRA,
)
async def to_code(config): async def to_code(config):
var = cg.new_Pvariable(config[CONF_ID]) var = cg.new_Pvariable(config[CONF_ID])
await cg.register_component(var, config) await cg.register_component(var, config)

View file

@ -40,6 +40,9 @@ CONFIG_SCHEMA = cv.All(
.extend(cv.polling_component_schema("5s")) .extend(cv.polling_component_schema("5s"))
.extend(uart.UART_DEVICE_SCHEMA) .extend(uart.UART_DEVICE_SCHEMA)
) )
FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema(
"sim800l", baud_rate=9600, require_tx=True, require_rx=True
)
async def to_code(config): async def to_code(config):
@ -54,10 +57,6 @@ async def to_code(config):
) )
def validate(config, item_config):
uart.validate_device("sim800l", config, item_config, baud_rate=9600)
SIM800L_SEND_SMS_SCHEMA = cv.Schema( SIM800L_SEND_SMS_SCHEMA = cv.Schema(
{ {
cv.GenerateID(): cv.use_id(Sim800LComponent), cv.GenerateID(): cv.use_id(Sim800LComponent),

View file

@ -1,5 +1,6 @@
import esphome.codegen as cg import esphome.codegen as cg
import esphome.config_validation as cv import esphome.config_validation as cv
import esphome.final_validate as fv
from esphome import pins from esphome import pins
from esphome.const import ( from esphome.const import (
CONF_CLK_PIN, CONF_CLK_PIN,
@ -69,9 +70,24 @@ async def register_spi_device(var, config):
cg.add(var.set_cs_pin(pin)) cg.add(var.set_cs_pin(pin))
def validate_device(name, config, item_config, require_mosi, require_miso): def final_validate_device_schema(name: str, *, require_mosi: bool, require_miso: bool):
spi_config = config.get_config_by_id(item_config[CONF_SPI_ID]) hub_schema = {}
if require_mosi and CONF_MISO_PIN not in spi_config: if require_miso:
raise ValueError(f"Component {name} requires parent spi to declare miso_pin") hub_schema[
if require_miso and CONF_MOSI_PIN not in spi_config: cv.Required(
raise ValueError(f"Component {name} requires parent spi to declare mosi_pin") CONF_MISO_PIN,
msg=f"Component {name} requires this spi bus to declare a miso_pin",
)
] = cv.valid
if require_mosi:
hub_schema[
cv.Required(
CONF_MOSI_PIN,
msg=f"Component {name} requires this spi bus to declare a mosi_pin",
)
] = cv.valid
return cv.Schema(
{cv.Required(CONF_SPI_ID): fv.id_declaration_match_schema(hub_schema)},
extra=cv.ALLOW_EXTRA,
)

View file

@ -1,5 +1,8 @@
from typing import Optional
import esphome.codegen as cg import esphome.codegen as cg
import esphome.config_validation as cv import esphome.config_validation as cv
import esphome.final_validate as fv
from esphome import pins, automation from esphome import pins, automation
from esphome.const import ( from esphome.const import (
CONF_BAUD_RATE, CONF_BAUD_RATE,
@ -92,42 +95,6 @@ async def to_code(config):
cg.add(var.set_parity(config[CONF_PARITY])) cg.add(var.set_parity(config[CONF_PARITY]))
def validate_device(
name, config, item_config, baud_rate=None, require_tx=True, require_rx=True
):
if not hasattr(config, "uart_devices"):
config.uart_devices = {}
devices = config.uart_devices
uart_config = config.get_config_by_id(item_config[CONF_UART_ID])
uart_id = uart_config[CONF_ID]
device = devices.setdefault(uart_id, {})
if require_tx:
if CONF_TX_PIN not in uart_config:
raise ValueError(f"Component {name} requires parent uart to declare tx_pin")
if CONF_TX_PIN in device:
raise ValueError(
f"Component {name} cannot use the same uart.{CONF_TX_PIN} as component {device[CONF_TX_PIN]} is already using it"
)
device[CONF_TX_PIN] = name
if require_rx:
if CONF_RX_PIN not in uart_config:
raise ValueError(f"Component {name} requires parent uart to declare rx_pin")
if CONF_RX_PIN in device:
raise ValueError(
f"Component {name} cannot use the same uart.{CONF_RX_PIN} as component {device[CONF_RX_PIN]} is already using it"
)
device[CONF_RX_PIN] = name
if baud_rate and uart_config[CONF_BAUD_RATE] != baud_rate:
raise ValueError(
f"Component {name} requires parent uart baud rate be {baud_rate}"
)
# A schema to use for all UART devices, all UART integrations must extend this! # A schema to use for all UART devices, all UART integrations must extend this!
UART_DEVICE_SCHEMA = cv.Schema( UART_DEVICE_SCHEMA = cv.Schema(
{ {
@ -135,6 +102,64 @@ UART_DEVICE_SCHEMA = cv.Schema(
} }
) )
KEY_UART_DEVICES = "uart_devices"
def final_validate_device_schema(
name: str,
*,
baud_rate: Optional[int] = None,
require_tx: bool = False,
require_rx: bool = False,
):
def validate_baud_rate(value):
if value != baud_rate:
raise cv.Invalid(
f"Component {name} required baud rate {baud_rate} for the uart bus"
)
return value
def validate_pin(opt, device):
def validator(value):
if opt in device:
raise cv.Invalid(
f"The uart {opt} is used both by {name} and {device[opt]}, "
f"but can only be used by one. Please create a new uart bus for {name}."
)
device[opt] = name
return value
return validator
def validate_hub(hub_config):
hub_schema = {}
uart_id = hub_config[CONF_ID]
devices = fv.full_config.get().data.setdefault(KEY_UART_DEVICES, {})
device = devices.setdefault(uart_id, {})
if require_tx:
hub_schema[
cv.Required(
CONF_TX_PIN,
msg=f"Component {name} requires this uart bus to declare a tx_pin",
)
] = validate_pin(CONF_TX_PIN, device)
if require_rx:
hub_schema[
cv.Required(
CONF_RX_PIN,
msg=f"Component {name} requires this uart bus to declare a rx_pin",
)
] = validate_pin(CONF_RX_PIN, device)
if baud_rate is not None:
hub_schema[cv.Required(CONF_BAUD_RATE)] = validate_baud_rate
return cv.Schema(hub_schema, extra=cv.ALLOW_EXTRA)(hub_config)
return cv.Schema(
{cv.Required(CONF_UART_ID): fv.id_declaration_match_schema(validate_hub)},
extra=cv.ALLOW_EXTRA,
)
async def register_uart_device(var, config): async def register_uart_device(var, config):
"""Register a UART device, setting up all the internal values. """Register a UART device, setting up all the internal values.

View file

@ -1,5 +1,6 @@
import esphome.codegen as cg import esphome.codegen as cg
import esphome.config_validation as cv import esphome.config_validation as cv
import esphome.final_validate as fv
from esphome import automation from esphome import automation
from esphome.automation import Condition from esphome.automation import Condition
from esphome.components.network import add_mdns_library from esphome.components.network import add_mdns_library
@ -137,18 +138,19 @@ WIFI_NETWORK_STA = WIFI_NETWORK_BASE.extend(
) )
def validate(config, item_config): def final_validate(config):
if ( has_sta = bool(config.get(CONF_NETWORKS, True))
(CONF_NETWORKS in item_config) has_ap = CONF_AP in config
and (item_config[CONF_NETWORKS] == []) has_improv = "esp32_improv" in fv.full_config.get()
and (CONF_AP not in item_config) if (not has_sta) and (not has_ap) and (not has_improv):
): raise cv.Invalid(
if "esp32_improv" not in config:
raise ValueError(
"Please specify at least an SSID or an Access Point to create." "Please specify at least an SSID or an Access Point to create."
) )
FINAL_VALIDATE_SCHEMA = cv.Schema(final_validate)
def _validate(config): def _validate(config):
if CONF_PASSWORD in config and CONF_SSID not in config: if CONF_PASSWORD in config and CONF_SSID not in config:
raise cv.Invalid("Cannot have WiFi password without SSID!") raise cv.Invalid("Cannot have WiFi password without SSID!")

View file

@ -21,11 +21,13 @@ from esphome.helpers import indent
from esphome.util import safe_print, OrderedDict from esphome.util import safe_print, OrderedDict
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from esphome.core import ConfigType
from esphome.loader import get_component, get_platform, ComponentManifest from esphome.loader import get_component, get_platform, ComponentManifest
from esphome.yaml_util import is_secret, ESPHomeDataBase, ESPForceValue from esphome.yaml_util import is_secret, ESPHomeDataBase, ESPForceValue
from esphome.voluptuous_schema import ExtraKeysInvalid from esphome.voluptuous_schema import ExtraKeysInvalid
from esphome.log import color, Fore from esphome.log import color, Fore
import esphome.final_validate as fv
import esphome.config_validation as cv
from esphome.types import ConfigType, ConfigPathType, ConfigFragmentType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -54,7 +56,7 @@ def _path_begins_with(path, other): # type: (ConfigPath, ConfigPath) -> bool
return path[: len(other)] == other return path[: len(other)] == other
class Config(OrderedDict): class Config(OrderedDict, fv.FinalValidateConfig):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# A list of voluptuous errors # A list of voluptuous errors
@ -65,6 +67,7 @@ class Config(OrderedDict):
self.output_paths = [] # type: List[Tuple[ConfigPath, str]] self.output_paths = [] # type: List[Tuple[ConfigPath, str]]
# A list of components ids with the config path # A list of components ids with the config path
self.declare_ids = [] # type: List[Tuple[core.ID, ConfigPath]] self.declare_ids = [] # type: List[Tuple[core.ID, ConfigPath]]
self._data = {}
def add_error(self, error): def add_error(self, error):
# type: (vol.Invalid) -> None # type: (vol.Invalid) -> None
@ -72,6 +75,12 @@ class Config(OrderedDict):
for err in error.errors: for err in error.errors:
self.add_error(err) self.add_error(err)
return return
if cv.ROOT_CONFIG_PATH in error.path:
# Root value means that the path before the root should be ignored
last_root = max(
i for i, v in enumerate(error.path) if v is cv.ROOT_CONFIG_PATH
)
error.path = error.path[last_root + 1 :]
self.errors.append(error) self.errors.append(error)
@contextmanager @contextmanager
@ -140,13 +149,16 @@ class Config(OrderedDict):
return doc_range return doc_range
def get_nested_item(self, path): def get_nested_item(
# type: (ConfigPath) -> ConfigType self, path: ConfigPathType, raise_error: bool = False
) -> ConfigFragmentType:
data = self data = self
for item_index in path: for item_index in path:
try: try:
data = data[item_index] data = data[item_index]
except (KeyError, IndexError, TypeError): except (KeyError, IndexError, TypeError):
if raise_error:
raise
return {} return {}
return data return data
@ -163,11 +175,20 @@ class Config(OrderedDict):
part.append(item_index) part.append(item_index)
return part return part
def get_config_by_id(self, id): def get_path_for_id(self, id: core.ID):
"""Return the config fragment where the given ID is declared."""
for declared_id, path in self.declare_ids: for declared_id, path in self.declare_ids:
if declared_id.id == str(id): if declared_id.id == str(id):
return self.get_nested_item(path[:-1]) return path
return None raise KeyError(f"ID {id} not found in configuration")
def get_config_for_path(self, path: ConfigPathType) -> ConfigFragmentType:
return self.get_nested_item(path, raise_error=True)
@property
def data(self):
"""Return temporary data used by final validation functions."""
return self._data
def iter_ids(config, path=None): def iter_ids(config, path=None):
@ -189,23 +210,22 @@ def do_id_pass(result): # type: (Config) -> None
from esphome.cpp_generator import MockObjClass from esphome.cpp_generator import MockObjClass
from esphome.cpp_types import Component from esphome.cpp_types import Component
declare_ids = result.declare_ids # type: List[Tuple[core.ID, ConfigPath]]
searching_ids = [] # type: List[Tuple[core.ID, ConfigPath]] searching_ids = [] # type: List[Tuple[core.ID, ConfigPath]]
for id, path in iter_ids(result): for id, path in iter_ids(result):
if id.is_declaration: if id.is_declaration:
if id.id is not None: if id.id is not None:
# Look for duplicate definitions # Look for duplicate definitions
match = next((v for v in declare_ids if v[0].id == id.id), None) match = next((v for v in result.declare_ids if v[0].id == id.id), None)
if match is not None: if match is not None:
opath = "->".join(str(v) for v in match[1]) opath = "->".join(str(v) for v in match[1])
result.add_str_error(f"ID {id.id} redefined! Check {opath}", path) result.add_str_error(f"ID {id.id} redefined! Check {opath}", path)
continue continue
declare_ids.append((id, path)) result.declare_ids.append((id, path))
else: else:
searching_ids.append((id, path)) searching_ids.append((id, path))
# Resolve default ids after manual IDs # Resolve default ids after manual IDs
for id, _ in declare_ids: for id, _ in result.declare_ids:
id.resolve([v[0].id for v in declare_ids]) id.resolve([v[0].id for v in result.declare_ids])
if isinstance(id.type, MockObjClass) and id.type.inherits_from(Component): if isinstance(id.type, MockObjClass) and id.type.inherits_from(Component):
CORE.component_ids.add(id.id) CORE.component_ids.add(id.id)
@ -213,7 +233,7 @@ def do_id_pass(result): # type: (Config) -> None
for id, path in searching_ids: for id, path in searching_ids:
if id.id is not None: if id.id is not None:
# manually declared # manually declared
match = next((v[0] for v in declare_ids if v[0].id == id.id), None) match = next((v[0] for v in result.declare_ids if v[0].id == id.id), None)
if match is None or not match.is_manual: if match is None or not match.is_manual:
# No declared ID with this name # No declared ID with this name
import difflib import difflib
@ -224,7 +244,7 @@ def do_id_pass(result): # type: (Config) -> None
) )
# Find candidates # Find candidates
matches = difflib.get_close_matches( matches = difflib.get_close_matches(
id.id, [v[0].id for v in declare_ids if v[0].is_manual] id.id, [v[0].id for v in result.declare_ids if v[0].is_manual]
) )
if matches: if matches:
matches_s = ", ".join(f'"{x}"' for x in matches) matches_s = ", ".join(f'"{x}"' for x in matches)
@ -245,7 +265,7 @@ def do_id_pass(result): # type: (Config) -> None
if id.id is None and id.type is not None: if id.id is None and id.type is not None:
matches = [] matches = []
for v in declare_ids: for v in result.declare_ids:
if v[0] is None or not isinstance(v[0].type, MockObjClass): if v[0] is None or not isinstance(v[0].type, MockObjClass):
continue continue
inherits = v[0].type.inherits_from(id.type) inherits = v[0].type.inherits_from(id.type)
@ -278,8 +298,6 @@ def do_id_pass(result): # type: (Config) -> None
def recursive_check_replaceme(value): def recursive_check_replaceme(value):
import esphome.config_validation as cv
if isinstance(value, list): if isinstance(value, list):
return cv.Schema([recursive_check_replaceme])(value) return cv.Schema([recursive_check_replaceme])(value)
if isinstance(value, dict): if isinstance(value, dict):
@ -558,14 +576,16 @@ def validate_config(config, command_line_substitutions):
# 7. Final validation # 7. Final validation
if not result.errors: if not result.errors:
# Inter - components validation # Inter - components validation
for path, conf, comp in validate_queue: token = fv.full_config.set(result)
if comp.config_schema is None:
for path, _, comp in validate_queue:
if comp.final_validate_schema is None:
continue continue
if callable(comp.validate): conf = result.get_nested_item(path)
try: with result.catch_error(path):
comp.validate(result, result.get_nested_item(path)) comp.final_validate_schema(conf)
except ValueError as err:
result.add_str_error(err, path) fv.full_config.reset(token)
return result return result
@ -621,8 +641,12 @@ def _format_vol_invalid(ex, config):
) )
elif "extra keys not allowed" in str(ex): elif "extra keys not allowed" in str(ex):
message += "[{}] is an invalid option for [{}].".format(ex.path[-1], paren) message += "[{}] is an invalid option for [{}].".format(ex.path[-1], paren)
elif "required key not provided" in str(ex): elif isinstance(ex, vol.RequiredFieldInvalid):
if ex.msg == "required key not provided":
message += "'{}' is a required option for [{}].".format(ex.path[-1], paren) message += "'{}' is a required option for [{}].".format(ex.path[-1], paren)
else:
# Required has set a custom error message
message += ex.msg
else: else:
message += humanize_error(config, ex) message += humanize_error(config, ex)

View file

@ -75,6 +75,9 @@ Inclusive = vol.Inclusive
ALLOW_EXTRA = vol.ALLOW_EXTRA ALLOW_EXTRA = vol.ALLOW_EXTRA
UNDEFINED = vol.UNDEFINED UNDEFINED = vol.UNDEFINED
RequiredFieldInvalid = vol.RequiredFieldInvalid RequiredFieldInvalid = vol.RequiredFieldInvalid
# this sentinel object can be placed in an 'Invalid' path to say
# the rest of the error path is relative to the root config path
ROOT_CONFIG_PATH = object()
RESERVED_IDS = [ RESERVED_IDS = [
# C++ keywords http://en.cppreference.com/w/cpp/keyword # C++ keywords http://en.cppreference.com/w/cpp/keyword
@ -218,8 +221,8 @@ class Required(vol.Required):
- *not* the `config.get(CONF_<KEY>)` syntax. - *not* the `config.get(CONF_<KEY>)` syntax.
""" """
def __init__(self, key): def __init__(self, key, msg=None):
super().__init__(key) super().__init__(key, msg=msg)
def check_not_templatable(value): def check_not_templatable(value):
@ -1073,6 +1076,7 @@ def invalid(message):
def valid(value): def valid(value):
"""A validator that is always valid and returns the value as-is."""
return value return value

View file

@ -2,7 +2,7 @@ import logging
import math import math
import os import os
import re import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from esphome.const import ( from esphome.const import (
CONF_ARDUINO_VERSION, CONF_ARDUINO_VERSION,
@ -23,6 +23,7 @@ from esphome.util import OrderedDict
if TYPE_CHECKING: if TYPE_CHECKING:
from ..cpp_generator import MockObj, MockObjClass, Statement from ..cpp_generator import MockObj, MockObjClass, Statement
from ..types import ConfigType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -462,9 +463,9 @@ class EsphomeCore:
# The board that's used (for example nodemcuv2) # The board that's used (for example nodemcuv2)
self.board: Optional[str] = None self.board: Optional[str] = None
# The full raw configuration # The full raw configuration
self.raw_config: Optional[ConfigType] = None self.raw_config: Optional["ConfigType"] = None
# The validated configuration, this is None until the config has been validated # The validated configuration, this is None until the config has been validated
self.config: Optional[ConfigType] = None self.config: Optional["ConfigType"] = None
# The pending tasks in the task queue (mostly for C++ generation) # The pending tasks in the task queue (mostly for C++ generation)
# This is a priority queue (with heapq) # This is a priority queue (with heapq)
# Each item is a tuple of form: (-priority, unique number, task) # Each item is a tuple of form: (-priority, unique number, task)
@ -752,6 +753,3 @@ class EnumValue:
CORE = EsphomeCore() CORE = EsphomeCore()
ConfigType = Dict[str, Any]
CoreType = EsphomeCore

View file

@ -8,7 +8,8 @@ from esphome.const import (
) )
# pylint: disable=unused-import # pylint: disable=unused-import
from esphome.core import coroutine, ID, CORE, ConfigType from esphome.core import coroutine, ID, CORE
from esphome.types import ConfigType
from esphome.cpp_generator import RawExpression, add, get_variable from esphome.cpp_generator import RawExpression, add, get_variable
from esphome.cpp_types import App, GPIOPin from esphome.cpp_types import App, GPIOPin
from esphome.util import Registry, RegistryEntry from esphome.util import Registry, RegistryEntry

57
esphome/final_validate.py Normal file
View file

@ -0,0 +1,57 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, Any
import contextvars
from esphome.types import ConfigFragmentType, ID, ConfigPathType
import esphome.config_validation as cv
class FinalValidateConfig(ABC):
@abstractproperty
def data(self) -> Dict[str, Any]:
"""A dictionary that can be used by post validation functions to store
global data during the validation phase. Each component should store its
data under a unique key
"""
@abstractmethod
def get_path_for_id(self, id: ID) -> ConfigPathType:
"""Get the config path a given ID has been declared in.
This is the location under the _validated_ config (for example, with cv.ensure_list applied)
Raises KeyError if the id was not declared in the configuration.
"""
@abstractmethod
def get_config_for_path(self, path: ConfigPathType) -> ConfigFragmentType:
"""Get the config fragment for the given global path.
Raises KeyError if a key in the path does not exist.
"""
FinalValidateConfig.register(dict)
# Context variable tracking the full config for some final validation functions.
full_config: contextvars.ContextVar[FinalValidateConfig] = contextvars.ContextVar(
"full_config"
)
def id_declaration_match_schema(schema):
"""A final-validation schema function that applies a schema to the outer config fragment of an
ID declaration.
This validator must be applied to ID values.
"""
if not isinstance(schema, cv.Schema):
schema = cv.Schema(schema, extra=cv.ALLOW_EXTRA)
def validator(value):
fconf = full_config.get()
path = fconf.get_path_for_id(value)[:-1]
declaration_config = fconf.get_config_for_path(path)
with cv.prepend_path([cv.ROOT_CONFIG_PATH] + path):
return schema(declaration_config)
return validator

View file

@ -12,6 +12,7 @@ from pathlib import Path
from esphome.const import ESP_PLATFORMS, SOURCE_FILE_EXTENSIONS from esphome.const import ESP_PLATFORMS, SOURCE_FILE_EXTENSIONS
import esphome.core.config import esphome.core.config
from esphome.core import CORE from esphome.core import CORE
from esphome.types import ConfigType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -81,8 +82,13 @@ class ComponentManifest:
return getattr(self.module, "CODEOWNERS", []) return getattr(self.module, "CODEOWNERS", [])
@property @property
def validate(self): def final_validate_schema(self) -> Optional[Callable[[ConfigType], None]]:
return getattr(self.module, "validate", None) """Components can declare a `FINAL_VALIDATE_SCHEMA` cv.Schema that gets called
after the main validation. In that function checks across components can be made.
Note that the function can't mutate the configuration - no changes are saved
"""
return getattr(self.module, "FINAL_VALIDATE_SCHEMA", None)
@property @property
def source_files(self) -> Dict[Path, SourceFile]: def source_files(self) -> Dict[Path, SourceFile]:

View file

@ -4,14 +4,13 @@ from datetime import datetime
import json import json
import logging import logging
import os import os
from typing import Any, Optional, List
from esphome import const from esphome import const
from esphome.core import CORE from esphome.core import CORE
from esphome.helpers import write_file_if_changed from esphome.helpers import write_file_if_changed
# pylint: disable=unused-import, wrong-import-order from esphome.types import CoreType
from esphome.core import CoreType
from typing import Any, Optional, List
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

18
esphome/types.py Normal file
View file

@ -0,0 +1,18 @@
"""This helper module tracks commonly used types in the esphome python codebase."""
from typing import Dict, Union, List
from esphome.core import ID, Lambda, EsphomeCore
ConfigFragmentType = Union[
str,
int,
float,
None,
Dict[Union[str, int], "ConfigFragmentType"],
List["ConfigFragmentType"],
ID,
Lambda,
]
ConfigType = Dict[str, ConfigFragmentType]
CoreType = EsphomeCore
ConfigPathType = Union[str, int]