From 214b419db25a6b71b121b75ea0dc1ceab9bdee78 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 14 Nov 2023 20:21:44 -0600 Subject: [PATCH] dashboard: Use mdns cache when available if device connection is OTA (#5724) * Use mdns or freshen cache when device connection is OTA Since we already have a service browser running, we likely already know the IP of the deivce we want to connect to so we can replace OTA with the address to avoid the esphome app having to look it up again * isort * Fix zeroconf name resolution refactoring error HostResolver should get the type as the first arg instead of the name * no i/o * tornado support native coros * lint * use new tornado start methods * use new tornado start methods * use new tornado start methods * break * lint * lint * typing, missing awaits * io in executor * missed one * fix: missing if * stale comment * rename run_command to build_device_command since it does not actually run anything --- esphome/dashboard/async_adapter.py | 56 +++++++ esphome/dashboard/dashboard.py | 231 +++++++++++++++++++---------- esphome/zeroconf.py | 60 ++++++-- 3 files changed, 255 insertions(+), 92 deletions(-) create mode 100644 esphome/dashboard/async_adapter.py diff --git a/esphome/dashboard/async_adapter.py b/esphome/dashboard/async_adapter.py new file mode 100644 index 0000000000..d6f4f6e1ff --- /dev/null +++ b/esphome/dashboard/async_adapter.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import asyncio +import threading + + +class ThreadedAsyncEvent: + """This is a shim to allow the asyncio event to be used in a threaded context. + + When more of the code is moved to asyncio, this can be removed. + """ + + def __init__(self) -> None: + """Initialize the ThreadedAsyncEvent.""" + self.event = threading.Event() + self.async_event: asyncio.Event | None = None + self.loop: asyncio.AbstractEventLoop | None = None + + def async_setup( + self, loop: asyncio.AbstractEventLoop, async_event: asyncio.Event + ) -> None: + """Set the asyncio.Event instance.""" + self.loop = loop + self.async_event = async_event + + def async_set(self) -> None: + """Set the asyncio.Event instance.""" + self.async_event.set() + self.event.set() + + def set(self) -> None: + """Set the event.""" + self.loop.call_soon_threadsafe(self.async_event.set) + self.event.set() + + def wait(self) -> None: + """Wait for the event.""" + self.event.wait() + + async def async_wait(self) -> None: + """Wait the event async.""" + await self.async_event.wait() + + def clear(self) -> None: + """Clear the event.""" + self.loop.call_soon_threadsafe(self.async_event.clear) + self.event.clear() + + def async_clear(self) -> None: + """Clear the event async.""" + self.async_event.clear() + self.event.clear() + + def is_set(self) -> bool: + """Return if the event is set.""" + return self.event.is_set() diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index 050564d21e..d7d11d8693 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -18,6 +18,7 @@ import shutil import subprocess import threading from pathlib import Path +from typing import Any import tornado import tornado.concurrent @@ -46,11 +47,12 @@ from esphome.storage_json import ( from esphome.util import get_serial_ports, shlex_quote from esphome.zeroconf import ( ESPHOME_SERVICE_TYPE, + AsyncEsphomeZeroconf, DashboardBrowser, DashboardImportDiscovery, DashboardStatus, - EsphomeZeroconf, ) +from .async_adapter import ThreadedAsyncEvent from .util import friendly_name_slugify, password_hash @@ -288,7 +290,10 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): self._use_popen = os.name == "nt" @authenticated - def on_message(self, message): + async def on_message( # pylint: disable=invalid-overridden-method + self, message: str + ) -> None: + # Since tornado 4.5, on_message is allowed to be a coroutine # Messages are always JSON, 500 when not json_message = json.loads(message) type_ = json_message["type"] @@ -298,14 +303,14 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): _LOGGER.warning("Requested unknown message type %s", type_) return - handlers[type_](self, json_message) + await handlers[type_](self, json_message) @websocket_method("spawn") - def handle_spawn(self, json_message): + async def handle_spawn(self, json_message: dict[str, Any]) -> None: if self._proc is not None: # spawn can only be called once return - command = self.build_command(json_message) + command = await self.build_command(json_message) _LOGGER.info("Running command '%s'", " ".join(shlex_quote(x) for x in command)) if self._use_popen: @@ -336,7 +341,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): return self._proc is not None and self._proc.returncode is None @websocket_method("stdin") - def handle_stdin(self, json_message): + async def handle_stdin(self, json_message: dict[str, Any]) -> None: if not self.is_process_active: return text: str = json_message["data"] @@ -345,7 +350,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): self._proc.stdin.write(data) @tornado.gen.coroutine - def _redirect_stdout(self): + def _redirect_stdout(self) -> None: reg = b"[\n\r]" while True: @@ -364,7 +369,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): _LOGGER.debug("> stdout: %s", text) self.write_message({"event": "line", "data": text}) - def _stdout_thread(self): + def _stdout_thread(self) -> None: if not self._use_popen: return while True: @@ -377,13 +382,13 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): self._proc.wait(1.0) self._queue.put_nowait(None) - def _proc_on_exit(self, returncode): + def _proc_on_exit(self, returncode: int) -> None: if not self._is_closed: # Check if the proc was not forcibly closed _LOGGER.info("Process exited with return code %s", returncode) self.write_message({"event": "exit", "code": returncode}) - def on_close(self): + def on_close(self) -> None: # Check if proc exists (if 'start' has been run) if self.is_process_active: _LOGGER.debug("Terminating process") @@ -394,32 +399,54 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): # Shutdown proc on WS close self._is_closed = True - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: raise NotImplementedError -class EsphomeLogsHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - config_file = settings.rel_path(json_message["configuration"]) +DASHBOARD_COMMAND = ["esphome", "--dashboard"] + + +class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): + """Base class for commands that require a port.""" + + async def build_device_command( + self, args: list[str], json_message: dict[str, Any] + ) -> list[str]: + """Build the command to run.""" + configuration = json_message["configuration"] + config_file = settings.rel_path(configuration) + port = json_message["port"] + if ( + port == "OTA" + and (mdns := MDNS_CONTAINER.get_mdns()) + and (host_name := mdns.filename_to_host_name_thread_safe(configuration)) + and (address := await mdns.async_resolve_host(host_name)) + ): + port = address + return [ - "esphome", - "--dashboard", - "logs", + *DASHBOARD_COMMAND, + *args, config_file, "--device", - json_message["port"], + port, ] +class EsphomeLogsHandler(EsphomePortCommandWebSocket): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + """Build the command to run.""" + return await self.build_device_command(["logs"], json_message) + + class EsphomeRenameHandler(EsphomeCommandWebSocket): old_name: str - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) self.old_name = json_message["configuration"] return [ - "esphome", - "--dashboard", + *DASHBOARD_COMMAND, "rename", config_file, json_message["newName"], @@ -435,36 +462,22 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket): PING_RESULT.pop(self.old_name, None) -class EsphomeUploadHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - config_file = settings.rel_path(json_message["configuration"]) - return [ - "esphome", - "--dashboard", - "upload", - config_file, - "--device", - json_message["port"], - ] +class EsphomeUploadHandler(EsphomePortCommandWebSocket): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + """Build the command to run.""" + return await self.build_device_command(["upload"], json_message) -class EsphomeRunHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - config_file = settings.rel_path(json_message["configuration"]) - return [ - "esphome", - "--dashboard", - "run", - config_file, - "--device", - json_message["port"], - ] +class EsphomeRunHandler(EsphomePortCommandWebSocket): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + """Build the command to run.""" + return await self.build_device_command(["run"], json_message) class EsphomeCompileHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) - command = ["esphome", "--dashboard", "compile"] + command = [*DASHBOARD_COMMAND, "compile"] if json_message.get("only_generate", False): command.append("--only-generate") command.append(config_file) @@ -472,39 +485,39 @@ class EsphomeCompileHandler(EsphomeCommandWebSocket): class EsphomeValidateHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) - command = ["esphome", "--dashboard", "config", config_file] + command = [*DASHBOARD_COMMAND, "config", config_file] if not settings.streamer_mode: command.append("--show-secrets") return command class EsphomeCleanMqttHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) - return ["esphome", "--dashboard", "clean-mqtt", config_file] + return [*DASHBOARD_COMMAND, "clean-mqtt", config_file] class EsphomeCleanHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) - return ["esphome", "--dashboard", "clean", config_file] + return [*DASHBOARD_COMMAND, "clean", config_file] class EsphomeVscodeHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - return ["esphome", "--dashboard", "-q", "vscode", "dummy"] + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + return [*DASHBOARD_COMMAND, "-q", "vscode", "dummy"] class EsphomeAceEditorHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - return ["esphome", "--dashboard", "-q", "vscode", "--ace", settings.config_dir] + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + return [*DASHBOARD_COMMAND, "-q", "vscode", "--ace", settings.config_dir] class EsphomeUpdateAllHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - return ["esphome", "--dashboard", "update-all", settings.config_dir] + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + return [*DASHBOARD_COMMAND, "update-all", settings.config_dir] class SerialPortRequestHandler(BaseHandler): @@ -838,8 +851,9 @@ class DashboardEntry: class ListDevicesHandler(BaseHandler): @authenticated - def get(self): - entries = _list_dashboard_entries() + async def get(self): + loop = asyncio.get_running_loop() + entries = await loop.run_in_executor(None, _list_dashboard_entries) self.set_header("content-type", "application/json") configured = {entry.name for entry in entries} self.write( @@ -963,24 +977,39 @@ class BoardsRequestHandler(BaseHandler): self.write(json.dumps(output)) -class MDNSStatusThread(threading.Thread): - def __init__(self): - """Initialize the MDNSStatusThread.""" +class MDNSStatus: + """Class that updates the mdns status.""" + + def __init__(self) -> None: + """Initialize the MDNSStatus class.""" super().__init__() + self.aiozc: AsyncEsphomeZeroconf | None = None # This is the current mdns state for each host (True, False, None) self.host_mdns_state: dict[str, bool | None] = {} # This is the hostnames to filenames mapping self.host_name_to_filename: dict[str, str] = {} + self.filename_to_host_name: dict[str, str] = {} # This is a set of host names to track (i.e no_mdns = false) self.host_name_with_mdns_enabled: set[set] = set() - self._refresh_hosts() + self._loop = asyncio.get_running_loop() - def _refresh_hosts(self): + def filename_to_host_name_thread_safe(self, filename: str) -> str | None: + """Resolve a filename to an address in a thread-safe manner.""" + return self.filename_to_host_name.get(filename) + + async def async_resolve_host(self, host_name: str) -> str | None: + """Resolve a host name to an address in a thread-safe manner.""" + if aiozc := self.aiozc: + return await aiozc.async_resolve_host(host_name) + return None + + async def async_refresh_hosts(self): """Refresh the hosts to track.""" - entries = _list_dashboard_entries() + entries = await self._loop.run_in_executor(None, _list_dashboard_entries) host_name_with_mdns_enabled = self.host_name_with_mdns_enabled host_mdns_state = self.host_mdns_state host_name_to_filename = self.host_name_to_filename + filename_to_host_name = self.filename_to_host_name for entry in entries: name = entry.name @@ -1003,11 +1032,13 @@ class MDNSStatusThread(threading.Thread): # so when we get an mdns update we can map it back # to the filename host_name_to_filename[name] = filename + filename_to_host_name[filename] = name - def run(self): + async def async_run(self) -> None: global IMPORT_RESULT - zc = EsphomeZeroconf() + aiozc = AsyncEsphomeZeroconf() + self.aiozc = aiozc host_mdns_state = self.host_mdns_state host_name_to_filename = self.host_name_to_filename host_name_with_mdns_enabled = self.host_name_with_mdns_enabled @@ -1020,21 +1051,23 @@ class MDNSStatusThread(threading.Thread): filename = host_name_to_filename[name] PING_RESULT[filename] = result - self._refresh_hosts() stat = DashboardStatus(on_update) imports = DashboardImportDiscovery() browser = DashboardBrowser( - zc, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback] + aiozc.zeroconf, + ESPHOME_SERVICE_TYPE, + [stat.browser_callback, imports.browser_callback], ) while not STOP_EVENT.is_set(): - self._refresh_hosts() + await self.async_refresh_hosts() IMPORT_RESULT = imports.import_state - PING_REQUEST.wait() - PING_REQUEST.clear() + await PING_REQUEST.async_wait() + PING_REQUEST.async_clear() - browser.cancel() - zc.close() + await browser.async_cancel() + await aiozc.async_close() + self.aiozc = None class PingStatusThread(threading.Thread): @@ -1211,11 +1244,26 @@ class UndoDeleteRequestHandler(BaseHandler): shutil.move(os.path.join(trash_path, configuration), config_file) +class MDNSContainer: + def __init__(self) -> None: + """Initialize the MDNSContainer.""" + self._mdns: MDNSStatus | None = None + + def set_mdns(self, mdns: MDNSStatus) -> None: + """Set the MDNSStatus instance.""" + self._mdns = mdns + + def get_mdns(self) -> MDNSStatus | None: + """Return the MDNSStatus instance.""" + return self._mdns + + PING_RESULT: dict = {} IMPORT_RESULT = {} STOP_EVENT = threading.Event() -PING_REQUEST = threading.Event() +PING_REQUEST = ThreadedAsyncEvent() MQTT_PING_REQUEST = threading.Event() +MDNS_CONTAINER = MDNSContainer() class LoginHandler(BaseHandler): @@ -1478,6 +1526,16 @@ def start_web_server(args): storage.save(path) settings.cookie_secret = storage.cookie_secret + try: + asyncio.run(async_start_web_server(args)) + except KeyboardInterrupt: + pass + + +async def async_start_web_server(args): + loop = asyncio.get_event_loop() + PING_REQUEST.async_setup(loop, asyncio.Event()) + app = make_app(args.verbose) if args.socket is not None: _LOGGER.info( @@ -1502,25 +1560,36 @@ def start_web_server(args): webbrowser.open(f"http://{args.address}:{args.port}") + mdns_task: asyncio.Task | None = None + ping_status_thread: PingStatusThread | None = None if settings.status_use_ping: - status_thread = PingStatusThread() + ping_status_thread = PingStatusThread() + ping_status_thread.start() else: - status_thread = MDNSStatusThread() - status_thread.start() + mdns_status = MDNSStatus() + await mdns_status.async_refresh_hosts() + MDNS_CONTAINER.set_mdns(mdns_status) + mdns_task = asyncio.create_task(mdns_status.async_run()) if settings.status_use_mqtt: status_thread_mqtt = MqttStatusThread() status_thread_mqtt.start() + shutdown_event = asyncio.Event() try: - tornado.ioloop.IOLoop.current().start() - except KeyboardInterrupt: + await shutdown_event.wait() + finally: _LOGGER.info("Shutting down...") STOP_EVENT.set() PING_REQUEST.set() - status_thread.join() + if ping_status_thread: + ping_status_thread.join() + MDNS_CONTAINER.set_mdns(None) + if mdns_task: + mdns_task.cancel() if settings.status_use_mqtt: status_thread_mqtt.join() MQTT_PING_REQUEST.set() if args.socket is not None: os.remove(args.socket) + await asyncio.sleep(0) diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index f4cb7f080b..956e348e07 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -1,22 +1,21 @@ from __future__ import annotations +import asyncio import logging from dataclasses import dataclass from typing import Callable -from zeroconf import ( - IPVersion, - ServiceBrowser, - ServiceInfo, - ServiceStateChange, - Zeroconf, -) +from zeroconf import IPVersion, ServiceInfo, ServiceStateChange, Zeroconf +from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf from esphome.storage_json import StorageJSON, ext_storage_path _LOGGER = logging.getLogger(__name__) +_BACKGROUND_TASKS: set[asyncio.Task] = set() + + class HostResolver(ServiceInfo): """Resolve a host name to an IP address.""" @@ -65,7 +64,7 @@ class DiscoveredImport: network: str -class DashboardBrowser(ServiceBrowser): +class DashboardBrowser(AsyncServiceBrowser): """A class to browse for ESPHome nodes.""" @@ -94,7 +93,28 @@ class DashboardImportDiscovery: # Ignore updates for devices that are not in the import state return - info = zeroconf.get_service_info(service_type, name) + info = AsyncServiceInfo( + service_type, + name, + ) + if info.load_from_cache(zeroconf): + self._process_service_info(name, info) + return + task = asyncio.create_task( + self._async_process_service_info(zeroconf, info, service_type, name) + ) + _BACKGROUND_TASKS.add(task) + task.add_done_callback(_BACKGROUND_TASKS.discard) + + async def _async_process_service_info( + self, zeroconf: Zeroconf, info: AsyncServiceInfo, service_type: str, name: str + ) -> None: + """Process a service info.""" + if await info.async_request(zeroconf): + self._process_service_info(name, info) + + def _process_service_info(self, name: str, info: ServiceInfo) -> None: + """Process a service info.""" _LOGGER.debug("-> resolved info: %s", info) if info is None: return @@ -146,14 +166,32 @@ class DashboardImportDiscovery: ) +def _make_host_resolver(host: str) -> HostResolver: + """Create a new HostResolver for the given host name.""" + name = host.partition(".")[0] + info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") + return info + + class EsphomeZeroconf(Zeroconf): def resolve_host(self, host: str, timeout: float = 3.0) -> str | None: """Resolve a host name to an IP address.""" - name = host.partition(".")[0] - info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") + info = _make_host_resolver(host) if ( info.load_from_cache(self) or (timeout and info.request(self, timeout * 1000)) ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): return str(addresses[0]) return None + + +class AsyncEsphomeZeroconf(AsyncZeroconf): + async def async_resolve_host(self, host: str, timeout: float = 3.0) -> str | None: + """Resolve a host name to an IP address.""" + info = _make_host_resolver(host) + if ( + info.load_from_cache(self.zeroconf) + or (timeout and await info.async_request(self.zeroconf, timeout * 1000)) + ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): + return str(addresses[0]) + return None