diff --git a/esphome/dashboard/async_adapter.py b/esphome/dashboard/async_adapter.py new file mode 100644 index 0000000000..d6f4f6e1ff --- /dev/null +++ b/esphome/dashboard/async_adapter.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import asyncio +import threading + + +class ThreadedAsyncEvent: + """This is a shim to allow the asyncio event to be used in a threaded context. + + When more of the code is moved to asyncio, this can be removed. + """ + + def __init__(self) -> None: + """Initialize the ThreadedAsyncEvent.""" + self.event = threading.Event() + self.async_event: asyncio.Event | None = None + self.loop: asyncio.AbstractEventLoop | None = None + + def async_setup( + self, loop: asyncio.AbstractEventLoop, async_event: asyncio.Event + ) -> None: + """Set the asyncio.Event instance.""" + self.loop = loop + self.async_event = async_event + + def async_set(self) -> None: + """Set the asyncio.Event instance.""" + self.async_event.set() + self.event.set() + + def set(self) -> None: + """Set the event.""" + self.loop.call_soon_threadsafe(self.async_event.set) + self.event.set() + + def wait(self) -> None: + """Wait for the event.""" + self.event.wait() + + async def async_wait(self) -> None: + """Wait the event async.""" + await self.async_event.wait() + + def clear(self) -> None: + """Clear the event.""" + self.loop.call_soon_threadsafe(self.async_event.clear) + self.event.clear() + + def async_clear(self) -> None: + """Clear the event async.""" + self.async_event.clear() + self.event.clear() + + def is_set(self) -> bool: + """Return if the event is set.""" + return self.event.is_set() diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index 050564d21e..d7d11d8693 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -18,6 +18,7 @@ import shutil import subprocess import threading from pathlib import Path +from typing import Any import tornado import tornado.concurrent @@ -46,11 +47,12 @@ from esphome.storage_json import ( from esphome.util import get_serial_ports, shlex_quote from esphome.zeroconf import ( ESPHOME_SERVICE_TYPE, + AsyncEsphomeZeroconf, DashboardBrowser, DashboardImportDiscovery, DashboardStatus, - EsphomeZeroconf, ) +from .async_adapter import ThreadedAsyncEvent from .util import friendly_name_slugify, password_hash @@ -288,7 +290,10 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): self._use_popen = os.name == "nt" @authenticated - def on_message(self, message): + async def on_message( # pylint: disable=invalid-overridden-method + self, message: str + ) -> None: + # Since tornado 4.5, on_message is allowed to be a coroutine # Messages are always JSON, 500 when not json_message = json.loads(message) type_ = json_message["type"] @@ -298,14 +303,14 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): _LOGGER.warning("Requested unknown message type %s", type_) return - handlers[type_](self, json_message) + await handlers[type_](self, json_message) @websocket_method("spawn") - def handle_spawn(self, json_message): + async def handle_spawn(self, json_message: dict[str, Any]) -> None: if self._proc is not None: # spawn can only be called once return - command = self.build_command(json_message) + command = await self.build_command(json_message) _LOGGER.info("Running command '%s'", " ".join(shlex_quote(x) for x in command)) if self._use_popen: @@ -336,7 +341,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): return self._proc is not None and self._proc.returncode is None @websocket_method("stdin") - def handle_stdin(self, json_message): + async def handle_stdin(self, json_message: dict[str, Any]) -> None: if not self.is_process_active: return text: str = json_message["data"] @@ -345,7 +350,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): self._proc.stdin.write(data) @tornado.gen.coroutine - def _redirect_stdout(self): + def _redirect_stdout(self) -> None: reg = b"[\n\r]" while True: @@ -364,7 +369,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): _LOGGER.debug("> stdout: %s", text) self.write_message({"event": "line", "data": text}) - def _stdout_thread(self): + def _stdout_thread(self) -> None: if not self._use_popen: return while True: @@ -377,13 +382,13 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): self._proc.wait(1.0) self._queue.put_nowait(None) - def _proc_on_exit(self, returncode): + def _proc_on_exit(self, returncode: int) -> None: if not self._is_closed: # Check if the proc was not forcibly closed _LOGGER.info("Process exited with return code %s", returncode) self.write_message({"event": "exit", "code": returncode}) - def on_close(self): + def on_close(self) -> None: # Check if proc exists (if 'start' has been run) if self.is_process_active: _LOGGER.debug("Terminating process") @@ -394,32 +399,54 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): # Shutdown proc on WS close self._is_closed = True - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: raise NotImplementedError -class EsphomeLogsHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - config_file = settings.rel_path(json_message["configuration"]) +DASHBOARD_COMMAND = ["esphome", "--dashboard"] + + +class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): + """Base class for commands that require a port.""" + + async def build_device_command( + self, args: list[str], json_message: dict[str, Any] + ) -> list[str]: + """Build the command to run.""" + configuration = json_message["configuration"] + config_file = settings.rel_path(configuration) + port = json_message["port"] + if ( + port == "OTA" + and (mdns := MDNS_CONTAINER.get_mdns()) + and (host_name := mdns.filename_to_host_name_thread_safe(configuration)) + and (address := await mdns.async_resolve_host(host_name)) + ): + port = address + return [ - "esphome", - "--dashboard", - "logs", + *DASHBOARD_COMMAND, + *args, config_file, "--device", - json_message["port"], + port, ] +class EsphomeLogsHandler(EsphomePortCommandWebSocket): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + """Build the command to run.""" + return await self.build_device_command(["logs"], json_message) + + class EsphomeRenameHandler(EsphomeCommandWebSocket): old_name: str - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) self.old_name = json_message["configuration"] return [ - "esphome", - "--dashboard", + *DASHBOARD_COMMAND, "rename", config_file, json_message["newName"], @@ -435,36 +462,22 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket): PING_RESULT.pop(self.old_name, None) -class EsphomeUploadHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - config_file = settings.rel_path(json_message["configuration"]) - return [ - "esphome", - "--dashboard", - "upload", - config_file, - "--device", - json_message["port"], - ] +class EsphomeUploadHandler(EsphomePortCommandWebSocket): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + """Build the command to run.""" + return await self.build_device_command(["upload"], json_message) -class EsphomeRunHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - config_file = settings.rel_path(json_message["configuration"]) - return [ - "esphome", - "--dashboard", - "run", - config_file, - "--device", - json_message["port"], - ] +class EsphomeRunHandler(EsphomePortCommandWebSocket): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + """Build the command to run.""" + return await self.build_device_command(["run"], json_message) class EsphomeCompileHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) - command = ["esphome", "--dashboard", "compile"] + command = [*DASHBOARD_COMMAND, "compile"] if json_message.get("only_generate", False): command.append("--only-generate") command.append(config_file) @@ -472,39 +485,39 @@ class EsphomeCompileHandler(EsphomeCommandWebSocket): class EsphomeValidateHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) - command = ["esphome", "--dashboard", "config", config_file] + command = [*DASHBOARD_COMMAND, "config", config_file] if not settings.streamer_mode: command.append("--show-secrets") return command class EsphomeCleanMqttHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) - return ["esphome", "--dashboard", "clean-mqtt", config_file] + return [*DASHBOARD_COMMAND, "clean-mqtt", config_file] class EsphomeCleanHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) - return ["esphome", "--dashboard", "clean", config_file] + return [*DASHBOARD_COMMAND, "clean", config_file] class EsphomeVscodeHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - return ["esphome", "--dashboard", "-q", "vscode", "dummy"] + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + return [*DASHBOARD_COMMAND, "-q", "vscode", "dummy"] class EsphomeAceEditorHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - return ["esphome", "--dashboard", "-q", "vscode", "--ace", settings.config_dir] + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + return [*DASHBOARD_COMMAND, "-q", "vscode", "--ace", settings.config_dir] class EsphomeUpdateAllHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): - return ["esphome", "--dashboard", "update-all", settings.config_dir] + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + return [*DASHBOARD_COMMAND, "update-all", settings.config_dir] class SerialPortRequestHandler(BaseHandler): @@ -838,8 +851,9 @@ class DashboardEntry: class ListDevicesHandler(BaseHandler): @authenticated - def get(self): - entries = _list_dashboard_entries() + async def get(self): + loop = asyncio.get_running_loop() + entries = await loop.run_in_executor(None, _list_dashboard_entries) self.set_header("content-type", "application/json") configured = {entry.name for entry in entries} self.write( @@ -963,24 +977,39 @@ class BoardsRequestHandler(BaseHandler): self.write(json.dumps(output)) -class MDNSStatusThread(threading.Thread): - def __init__(self): - """Initialize the MDNSStatusThread.""" +class MDNSStatus: + """Class that updates the mdns status.""" + + def __init__(self) -> None: + """Initialize the MDNSStatus class.""" super().__init__() + self.aiozc: AsyncEsphomeZeroconf | None = None # This is the current mdns state for each host (True, False, None) self.host_mdns_state: dict[str, bool | None] = {} # This is the hostnames to filenames mapping self.host_name_to_filename: dict[str, str] = {} + self.filename_to_host_name: dict[str, str] = {} # This is a set of host names to track (i.e no_mdns = false) self.host_name_with_mdns_enabled: set[set] = set() - self._refresh_hosts() + self._loop = asyncio.get_running_loop() - def _refresh_hosts(self): + def filename_to_host_name_thread_safe(self, filename: str) -> str | None: + """Resolve a filename to an address in a thread-safe manner.""" + return self.filename_to_host_name.get(filename) + + async def async_resolve_host(self, host_name: str) -> str | None: + """Resolve a host name to an address in a thread-safe manner.""" + if aiozc := self.aiozc: + return await aiozc.async_resolve_host(host_name) + return None + + async def async_refresh_hosts(self): """Refresh the hosts to track.""" - entries = _list_dashboard_entries() + entries = await self._loop.run_in_executor(None, _list_dashboard_entries) host_name_with_mdns_enabled = self.host_name_with_mdns_enabled host_mdns_state = self.host_mdns_state host_name_to_filename = self.host_name_to_filename + filename_to_host_name = self.filename_to_host_name for entry in entries: name = entry.name @@ -1003,11 +1032,13 @@ class MDNSStatusThread(threading.Thread): # so when we get an mdns update we can map it back # to the filename host_name_to_filename[name] = filename + filename_to_host_name[filename] = name - def run(self): + async def async_run(self) -> None: global IMPORT_RESULT - zc = EsphomeZeroconf() + aiozc = AsyncEsphomeZeroconf() + self.aiozc = aiozc host_mdns_state = self.host_mdns_state host_name_to_filename = self.host_name_to_filename host_name_with_mdns_enabled = self.host_name_with_mdns_enabled @@ -1020,21 +1051,23 @@ class MDNSStatusThread(threading.Thread): filename = host_name_to_filename[name] PING_RESULT[filename] = result - self._refresh_hosts() stat = DashboardStatus(on_update) imports = DashboardImportDiscovery() browser = DashboardBrowser( - zc, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback] + aiozc.zeroconf, + ESPHOME_SERVICE_TYPE, + [stat.browser_callback, imports.browser_callback], ) while not STOP_EVENT.is_set(): - self._refresh_hosts() + await self.async_refresh_hosts() IMPORT_RESULT = imports.import_state - PING_REQUEST.wait() - PING_REQUEST.clear() + await PING_REQUEST.async_wait() + PING_REQUEST.async_clear() - browser.cancel() - zc.close() + await browser.async_cancel() + await aiozc.async_close() + self.aiozc = None class PingStatusThread(threading.Thread): @@ -1211,11 +1244,26 @@ class UndoDeleteRequestHandler(BaseHandler): shutil.move(os.path.join(trash_path, configuration), config_file) +class MDNSContainer: + def __init__(self) -> None: + """Initialize the MDNSContainer.""" + self._mdns: MDNSStatus | None = None + + def set_mdns(self, mdns: MDNSStatus) -> None: + """Set the MDNSStatus instance.""" + self._mdns = mdns + + def get_mdns(self) -> MDNSStatus | None: + """Return the MDNSStatus instance.""" + return self._mdns + + PING_RESULT: dict = {} IMPORT_RESULT = {} STOP_EVENT = threading.Event() -PING_REQUEST = threading.Event() +PING_REQUEST = ThreadedAsyncEvent() MQTT_PING_REQUEST = threading.Event() +MDNS_CONTAINER = MDNSContainer() class LoginHandler(BaseHandler): @@ -1478,6 +1526,16 @@ def start_web_server(args): storage.save(path) settings.cookie_secret = storage.cookie_secret + try: + asyncio.run(async_start_web_server(args)) + except KeyboardInterrupt: + pass + + +async def async_start_web_server(args): + loop = asyncio.get_event_loop() + PING_REQUEST.async_setup(loop, asyncio.Event()) + app = make_app(args.verbose) if args.socket is not None: _LOGGER.info( @@ -1502,25 +1560,36 @@ def start_web_server(args): webbrowser.open(f"http://{args.address}:{args.port}") + mdns_task: asyncio.Task | None = None + ping_status_thread: PingStatusThread | None = None if settings.status_use_ping: - status_thread = PingStatusThread() + ping_status_thread = PingStatusThread() + ping_status_thread.start() else: - status_thread = MDNSStatusThread() - status_thread.start() + mdns_status = MDNSStatus() + await mdns_status.async_refresh_hosts() + MDNS_CONTAINER.set_mdns(mdns_status) + mdns_task = asyncio.create_task(mdns_status.async_run()) if settings.status_use_mqtt: status_thread_mqtt = MqttStatusThread() status_thread_mqtt.start() + shutdown_event = asyncio.Event() try: - tornado.ioloop.IOLoop.current().start() - except KeyboardInterrupt: + await shutdown_event.wait() + finally: _LOGGER.info("Shutting down...") STOP_EVENT.set() PING_REQUEST.set() - status_thread.join() + if ping_status_thread: + ping_status_thread.join() + MDNS_CONTAINER.set_mdns(None) + if mdns_task: + mdns_task.cancel() if settings.status_use_mqtt: status_thread_mqtt.join() MQTT_PING_REQUEST.set() if args.socket is not None: os.remove(args.socket) + await asyncio.sleep(0) diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index f4cb7f080b..956e348e07 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -1,22 +1,21 @@ from __future__ import annotations +import asyncio import logging from dataclasses import dataclass from typing import Callable -from zeroconf import ( - IPVersion, - ServiceBrowser, - ServiceInfo, - ServiceStateChange, - Zeroconf, -) +from zeroconf import IPVersion, ServiceInfo, ServiceStateChange, Zeroconf +from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf from esphome.storage_json import StorageJSON, ext_storage_path _LOGGER = logging.getLogger(__name__) +_BACKGROUND_TASKS: set[asyncio.Task] = set() + + class HostResolver(ServiceInfo): """Resolve a host name to an IP address.""" @@ -65,7 +64,7 @@ class DiscoveredImport: network: str -class DashboardBrowser(ServiceBrowser): +class DashboardBrowser(AsyncServiceBrowser): """A class to browse for ESPHome nodes.""" @@ -94,7 +93,28 @@ class DashboardImportDiscovery: # Ignore updates for devices that are not in the import state return - info = zeroconf.get_service_info(service_type, name) + info = AsyncServiceInfo( + service_type, + name, + ) + if info.load_from_cache(zeroconf): + self._process_service_info(name, info) + return + task = asyncio.create_task( + self._async_process_service_info(zeroconf, info, service_type, name) + ) + _BACKGROUND_TASKS.add(task) + task.add_done_callback(_BACKGROUND_TASKS.discard) + + async def _async_process_service_info( + self, zeroconf: Zeroconf, info: AsyncServiceInfo, service_type: str, name: str + ) -> None: + """Process a service info.""" + if await info.async_request(zeroconf): + self._process_service_info(name, info) + + def _process_service_info(self, name: str, info: ServiceInfo) -> None: + """Process a service info.""" _LOGGER.debug("-> resolved info: %s", info) if info is None: return @@ -146,14 +166,32 @@ class DashboardImportDiscovery: ) +def _make_host_resolver(host: str) -> HostResolver: + """Create a new HostResolver for the given host name.""" + name = host.partition(".")[0] + info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") + return info + + class EsphomeZeroconf(Zeroconf): def resolve_host(self, host: str, timeout: float = 3.0) -> str | None: """Resolve a host name to an IP address.""" - name = host.partition(".")[0] - info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") + info = _make_host_resolver(host) if ( info.load_from_cache(self) or (timeout and info.request(self, timeout * 1000)) ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): return str(addresses[0]) return None + + +class AsyncEsphomeZeroconf(AsyncZeroconf): + async def async_resolve_host(self, host: str, timeout: float = 3.0) -> str | None: + """Resolve a host name to an IP address.""" + info = _make_host_resolver(host) + if ( + info.load_from_cache(self.zeroconf) + or (timeout and await info.async_request(self.zeroconf, timeout * 1000)) + ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): + return str(addresses[0]) + return None