dashboard: Use mdns cache when available if device connection is OTA (#5724)

* Use mdns or freshen cache when device connection is OTA

Since we already have a service browser running, we likely
already know the IP of the deivce we want to connect to so
we can replace OTA with the address to avoid the esphome
app having to look it up again

* isort

* Fix zeroconf name resolution refactoring error

HostResolver should get the type as the first arg instead
of the name

* no i/o

* tornado support native coros

* lint

* use new tornado start methods

* use new tornado start methods

* use new tornado start methods

* break

* lint

* lint

* typing, missing awaits

* io in executor

* missed one

* fix: missing if

* stale comment

* rename run_command to build_device_command since it does not actually run anything
This commit is contained in:
J. Nick Koston 2023-11-14 20:21:44 -06:00 committed by GitHub
parent cdcb25be8e
commit 214b419db2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 255 additions and 92 deletions

View file

@ -0,0 +1,56 @@
from __future__ import annotations
import asyncio
import threading
class ThreadedAsyncEvent:
"""This is a shim to allow the asyncio event to be used in a threaded context.
When more of the code is moved to asyncio, this can be removed.
"""
def __init__(self) -> None:
"""Initialize the ThreadedAsyncEvent."""
self.event = threading.Event()
self.async_event: asyncio.Event | None = None
self.loop: asyncio.AbstractEventLoop | None = None
def async_setup(
self, loop: asyncio.AbstractEventLoop, async_event: asyncio.Event
) -> None:
"""Set the asyncio.Event instance."""
self.loop = loop
self.async_event = async_event
def async_set(self) -> None:
"""Set the asyncio.Event instance."""
self.async_event.set()
self.event.set()
def set(self) -> None:
"""Set the event."""
self.loop.call_soon_threadsafe(self.async_event.set)
self.event.set()
def wait(self) -> None:
"""Wait for the event."""
self.event.wait()
async def async_wait(self) -> None:
"""Wait the event async."""
await self.async_event.wait()
def clear(self) -> None:
"""Clear the event."""
self.loop.call_soon_threadsafe(self.async_event.clear)
self.event.clear()
def async_clear(self) -> None:
"""Clear the event async."""
self.async_event.clear()
self.event.clear()
def is_set(self) -> bool:
"""Return if the event is set."""
return self.event.is_set()

View file

@ -18,6 +18,7 @@ import shutil
import subprocess
import threading
from pathlib import Path
from typing import Any
import tornado
import tornado.concurrent
@ -46,11 +47,12 @@ from esphome.storage_json import (
from esphome.util import get_serial_ports, shlex_quote
from esphome.zeroconf import (
ESPHOME_SERVICE_TYPE,
AsyncEsphomeZeroconf,
DashboardBrowser,
DashboardImportDiscovery,
DashboardStatus,
EsphomeZeroconf,
)
from .async_adapter import ThreadedAsyncEvent
from .util import friendly_name_slugify, password_hash
@ -288,7 +290,10 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
self._use_popen = os.name == "nt"
@authenticated
def on_message(self, message):
async def on_message( # pylint: disable=invalid-overridden-method
self, message: str
) -> None:
# Since tornado 4.5, on_message is allowed to be a coroutine
# Messages are always JSON, 500 when not
json_message = json.loads(message)
type_ = json_message["type"]
@ -298,14 +303,14 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
_LOGGER.warning("Requested unknown message type %s", type_)
return
handlers[type_](self, json_message)
await handlers[type_](self, json_message)
@websocket_method("spawn")
def handle_spawn(self, json_message):
async def handle_spawn(self, json_message: dict[str, Any]) -> None:
if self._proc is not None:
# spawn can only be called once
return
command = self.build_command(json_message)
command = await self.build_command(json_message)
_LOGGER.info("Running command '%s'", " ".join(shlex_quote(x) for x in command))
if self._use_popen:
@ -336,7 +341,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
return self._proc is not None and self._proc.returncode is None
@websocket_method("stdin")
def handle_stdin(self, json_message):
async def handle_stdin(self, json_message: dict[str, Any]) -> None:
if not self.is_process_active:
return
text: str = json_message["data"]
@ -345,7 +350,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
self._proc.stdin.write(data)
@tornado.gen.coroutine
def _redirect_stdout(self):
def _redirect_stdout(self) -> None:
reg = b"[\n\r]"
while True:
@ -364,7 +369,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
_LOGGER.debug("> stdout: %s", text)
self.write_message({"event": "line", "data": text})
def _stdout_thread(self):
def _stdout_thread(self) -> None:
if not self._use_popen:
return
while True:
@ -377,13 +382,13 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
self._proc.wait(1.0)
self._queue.put_nowait(None)
def _proc_on_exit(self, returncode):
def _proc_on_exit(self, returncode: int) -> None:
if not self._is_closed:
# Check if the proc was not forcibly closed
_LOGGER.info("Process exited with return code %s", returncode)
self.write_message({"event": "exit", "code": returncode})
def on_close(self):
def on_close(self) -> None:
# Check if proc exists (if 'start' has been run)
if self.is_process_active:
_LOGGER.debug("Terminating process")
@ -394,32 +399,54 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
# Shutdown proc on WS close
self._is_closed = True
def build_command(self, json_message):
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
raise NotImplementedError
class EsphomeLogsHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
DASHBOARD_COMMAND = ["esphome", "--dashboard"]
class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
"""Base class for commands that require a port."""
async def build_device_command(
self, args: list[str], json_message: dict[str, Any]
) -> list[str]:
"""Build the command to run."""
configuration = json_message["configuration"]
config_file = settings.rel_path(configuration)
port = json_message["port"]
if (
port == "OTA"
and (mdns := MDNS_CONTAINER.get_mdns())
and (host_name := mdns.filename_to_host_name_thread_safe(configuration))
and (address := await mdns.async_resolve_host(host_name))
):
port = address
return [
"esphome",
"--dashboard",
"logs",
*DASHBOARD_COMMAND,
*args,
config_file,
"--device",
json_message["port"],
port,
]
class EsphomeLogsHandler(EsphomePortCommandWebSocket):
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
"""Build the command to run."""
return await self.build_device_command(["logs"], json_message)
class EsphomeRenameHandler(EsphomeCommandWebSocket):
old_name: str
def build_command(self, json_message):
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
config_file = settings.rel_path(json_message["configuration"])
self.old_name = json_message["configuration"]
return [
"esphome",
"--dashboard",
*DASHBOARD_COMMAND,
"rename",
config_file,
json_message["newName"],
@ -435,36 +462,22 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket):
PING_RESULT.pop(self.old_name, None)
class EsphomeUploadHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
return [
"esphome",
"--dashboard",
"upload",
config_file,
"--device",
json_message["port"],
]
class EsphomeUploadHandler(EsphomePortCommandWebSocket):
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
"""Build the command to run."""
return await self.build_device_command(["upload"], json_message)
class EsphomeRunHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
return [
"esphome",
"--dashboard",
"run",
config_file,
"--device",
json_message["port"],
]
class EsphomeRunHandler(EsphomePortCommandWebSocket):
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
"""Build the command to run."""
return await self.build_device_command(["run"], json_message)
class EsphomeCompileHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
config_file = settings.rel_path(json_message["configuration"])
command = ["esphome", "--dashboard", "compile"]
command = [*DASHBOARD_COMMAND, "compile"]
if json_message.get("only_generate", False):
command.append("--only-generate")
command.append(config_file)
@ -472,39 +485,39 @@ class EsphomeCompileHandler(EsphomeCommandWebSocket):
class EsphomeValidateHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
config_file = settings.rel_path(json_message["configuration"])
command = ["esphome", "--dashboard", "config", config_file]
command = [*DASHBOARD_COMMAND, "config", config_file]
if not settings.streamer_mode:
command.append("--show-secrets")
return command
class EsphomeCleanMqttHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
config_file = settings.rel_path(json_message["configuration"])
return ["esphome", "--dashboard", "clean-mqtt", config_file]
return [*DASHBOARD_COMMAND, "clean-mqtt", config_file]
class EsphomeCleanHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
config_file = settings.rel_path(json_message["configuration"])
return ["esphome", "--dashboard", "clean", config_file]
return [*DASHBOARD_COMMAND, "clean", config_file]
class EsphomeVscodeHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
return ["esphome", "--dashboard", "-q", "vscode", "dummy"]
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
return [*DASHBOARD_COMMAND, "-q", "vscode", "dummy"]
class EsphomeAceEditorHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
return ["esphome", "--dashboard", "-q", "vscode", "--ace", settings.config_dir]
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
return [*DASHBOARD_COMMAND, "-q", "vscode", "--ace", settings.config_dir]
class EsphomeUpdateAllHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
return ["esphome", "--dashboard", "update-all", settings.config_dir]
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
return [*DASHBOARD_COMMAND, "update-all", settings.config_dir]
class SerialPortRequestHandler(BaseHandler):
@ -838,8 +851,9 @@ class DashboardEntry:
class ListDevicesHandler(BaseHandler):
@authenticated
def get(self):
entries = _list_dashboard_entries()
async def get(self):
loop = asyncio.get_running_loop()
entries = await loop.run_in_executor(None, _list_dashboard_entries)
self.set_header("content-type", "application/json")
configured = {entry.name for entry in entries}
self.write(
@ -963,24 +977,39 @@ class BoardsRequestHandler(BaseHandler):
self.write(json.dumps(output))
class MDNSStatusThread(threading.Thread):
def __init__(self):
"""Initialize the MDNSStatusThread."""
class MDNSStatus:
"""Class that updates the mdns status."""
def __init__(self) -> None:
"""Initialize the MDNSStatus class."""
super().__init__()
self.aiozc: AsyncEsphomeZeroconf | None = None
# This is the current mdns state for each host (True, False, None)
self.host_mdns_state: dict[str, bool | None] = {}
# This is the hostnames to filenames mapping
self.host_name_to_filename: dict[str, str] = {}
self.filename_to_host_name: dict[str, str] = {}
# This is a set of host names to track (i.e no_mdns = false)
self.host_name_with_mdns_enabled: set[set] = set()
self._refresh_hosts()
self._loop = asyncio.get_running_loop()
def _refresh_hosts(self):
def filename_to_host_name_thread_safe(self, filename: str) -> str | None:
"""Resolve a filename to an address in a thread-safe manner."""
return self.filename_to_host_name.get(filename)
async def async_resolve_host(self, host_name: str) -> str | None:
"""Resolve a host name to an address in a thread-safe manner."""
if aiozc := self.aiozc:
return await aiozc.async_resolve_host(host_name)
return None
async def async_refresh_hosts(self):
"""Refresh the hosts to track."""
entries = _list_dashboard_entries()
entries = await self._loop.run_in_executor(None, _list_dashboard_entries)
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
host_mdns_state = self.host_mdns_state
host_name_to_filename = self.host_name_to_filename
filename_to_host_name = self.filename_to_host_name
for entry in entries:
name = entry.name
@ -1003,11 +1032,13 @@ class MDNSStatusThread(threading.Thread):
# so when we get an mdns update we can map it back
# to the filename
host_name_to_filename[name] = filename
filename_to_host_name[filename] = name
def run(self):
async def async_run(self) -> None:
global IMPORT_RESULT
zc = EsphomeZeroconf()
aiozc = AsyncEsphomeZeroconf()
self.aiozc = aiozc
host_mdns_state = self.host_mdns_state
host_name_to_filename = self.host_name_to_filename
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
@ -1020,21 +1051,23 @@ class MDNSStatusThread(threading.Thread):
filename = host_name_to_filename[name]
PING_RESULT[filename] = result
self._refresh_hosts()
stat = DashboardStatus(on_update)
imports = DashboardImportDiscovery()
browser = DashboardBrowser(
zc, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback]
aiozc.zeroconf,
ESPHOME_SERVICE_TYPE,
[stat.browser_callback, imports.browser_callback],
)
while not STOP_EVENT.is_set():
self._refresh_hosts()
await self.async_refresh_hosts()
IMPORT_RESULT = imports.import_state
PING_REQUEST.wait()
PING_REQUEST.clear()
await PING_REQUEST.async_wait()
PING_REQUEST.async_clear()
browser.cancel()
zc.close()
await browser.async_cancel()
await aiozc.async_close()
self.aiozc = None
class PingStatusThread(threading.Thread):
@ -1211,11 +1244,26 @@ class UndoDeleteRequestHandler(BaseHandler):
shutil.move(os.path.join(trash_path, configuration), config_file)
class MDNSContainer:
def __init__(self) -> None:
"""Initialize the MDNSContainer."""
self._mdns: MDNSStatus | None = None
def set_mdns(self, mdns: MDNSStatus) -> None:
"""Set the MDNSStatus instance."""
self._mdns = mdns
def get_mdns(self) -> MDNSStatus | None:
"""Return the MDNSStatus instance."""
return self._mdns
PING_RESULT: dict = {}
IMPORT_RESULT = {}
STOP_EVENT = threading.Event()
PING_REQUEST = threading.Event()
PING_REQUEST = ThreadedAsyncEvent()
MQTT_PING_REQUEST = threading.Event()
MDNS_CONTAINER = MDNSContainer()
class LoginHandler(BaseHandler):
@ -1478,6 +1526,16 @@ def start_web_server(args):
storage.save(path)
settings.cookie_secret = storage.cookie_secret
try:
asyncio.run(async_start_web_server(args))
except KeyboardInterrupt:
pass
async def async_start_web_server(args):
loop = asyncio.get_event_loop()
PING_REQUEST.async_setup(loop, asyncio.Event())
app = make_app(args.verbose)
if args.socket is not None:
_LOGGER.info(
@ -1502,25 +1560,36 @@ def start_web_server(args):
webbrowser.open(f"http://{args.address}:{args.port}")
mdns_task: asyncio.Task | None = None
ping_status_thread: PingStatusThread | None = None
if settings.status_use_ping:
status_thread = PingStatusThread()
ping_status_thread = PingStatusThread()
ping_status_thread.start()
else:
status_thread = MDNSStatusThread()
status_thread.start()
mdns_status = MDNSStatus()
await mdns_status.async_refresh_hosts()
MDNS_CONTAINER.set_mdns(mdns_status)
mdns_task = asyncio.create_task(mdns_status.async_run())
if settings.status_use_mqtt:
status_thread_mqtt = MqttStatusThread()
status_thread_mqtt.start()
shutdown_event = asyncio.Event()
try:
tornado.ioloop.IOLoop.current().start()
except KeyboardInterrupt:
await shutdown_event.wait()
finally:
_LOGGER.info("Shutting down...")
STOP_EVENT.set()
PING_REQUEST.set()
status_thread.join()
if ping_status_thread:
ping_status_thread.join()
MDNS_CONTAINER.set_mdns(None)
if mdns_task:
mdns_task.cancel()
if settings.status_use_mqtt:
status_thread_mqtt.join()
MQTT_PING_REQUEST.set()
if args.socket is not None:
os.remove(args.socket)
await asyncio.sleep(0)

View file

@ -1,22 +1,21 @@
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from typing import Callable
from zeroconf import (
IPVersion,
ServiceBrowser,
ServiceInfo,
ServiceStateChange,
Zeroconf,
)
from zeroconf import IPVersion, ServiceInfo, ServiceStateChange, Zeroconf
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
from esphome.storage_json import StorageJSON, ext_storage_path
_LOGGER = logging.getLogger(__name__)
_BACKGROUND_TASKS: set[asyncio.Task] = set()
class HostResolver(ServiceInfo):
"""Resolve a host name to an IP address."""
@ -65,7 +64,7 @@ class DiscoveredImport:
network: str
class DashboardBrowser(ServiceBrowser):
class DashboardBrowser(AsyncServiceBrowser):
"""A class to browse for ESPHome nodes."""
@ -94,7 +93,28 @@ class DashboardImportDiscovery:
# Ignore updates for devices that are not in the import state
return
info = zeroconf.get_service_info(service_type, name)
info = AsyncServiceInfo(
service_type,
name,
)
if info.load_from_cache(zeroconf):
self._process_service_info(name, info)
return
task = asyncio.create_task(
self._async_process_service_info(zeroconf, info, service_type, name)
)
_BACKGROUND_TASKS.add(task)
task.add_done_callback(_BACKGROUND_TASKS.discard)
async def _async_process_service_info(
self, zeroconf: Zeroconf, info: AsyncServiceInfo, service_type: str, name: str
) -> None:
"""Process a service info."""
if await info.async_request(zeroconf):
self._process_service_info(name, info)
def _process_service_info(self, name: str, info: ServiceInfo) -> None:
"""Process a service info."""
_LOGGER.debug("-> resolved info: %s", info)
if info is None:
return
@ -146,14 +166,32 @@ class DashboardImportDiscovery:
)
def _make_host_resolver(host: str) -> HostResolver:
"""Create a new HostResolver for the given host name."""
name = host.partition(".")[0]
info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}")
return info
class EsphomeZeroconf(Zeroconf):
def resolve_host(self, host: str, timeout: float = 3.0) -> str | None:
"""Resolve a host name to an IP address."""
name = host.partition(".")[0]
info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}")
info = _make_host_resolver(host)
if (
info.load_from_cache(self)
or (timeout and info.request(self, timeout * 1000))
) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)):
return str(addresses[0])
return None
class AsyncEsphomeZeroconf(AsyncZeroconf):
async def async_resolve_host(self, host: str, timeout: float = 3.0) -> str | None:
"""Resolve a host name to an IP address."""
info = _make_host_resolver(host)
if (
info.load_from_cache(self.zeroconf)
or (timeout and await info.async_request(self.zeroconf, timeout * 1000))
) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)):
return str(addresses[0])
return None