From 3c243e663f18723570cf313d502da85fec2a063e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 17 Nov 2023 18:33:10 -0600 Subject: [PATCH] dashboard: Add support for firing events (#5775) * dashboard: fire events when entry is updated or state changes * dashboard: fire events when entry is updated or state changes * dashboard: fire events when entry is updated or state changes * tweaks * fixes * remove typing_extensions * rename for asyncio * rename for asyncio * rename for asyncio * preen * lint * lint * move dict converter * lint --- esphome/dashboard/const.py | 8 ++ esphome/dashboard/core.py | 49 ++++++++++- esphome/dashboard/entries.py | 141 +++++++++++++++++++++++++++---- esphome/dashboard/enum.py | 19 +++++ esphome/dashboard/status/mdns.py | 48 ++++++----- esphome/dashboard/status/mqtt.py | 17 ++-- esphome/dashboard/status/ping.py | 10 +-- esphome/dashboard/web_server.py | 37 ++++---- 8 files changed, 251 insertions(+), 78 deletions(-) create mode 100644 esphome/dashboard/const.py create mode 100644 esphome/dashboard/enum.py diff --git a/esphome/dashboard/const.py b/esphome/dashboard/const.py new file mode 100644 index 0000000000..ed2b81d3e8 --- /dev/null +++ b/esphome/dashboard/const.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +EVENT_ENTRY_ADDED = "entry_added" +EVENT_ENTRY_REMOVED = "entry_removed" +EVENT_ENTRY_UPDATED = "entry_updated" +EVENT_ENTRY_STATE_CHANGED = "entry_state_changed" + +SENTINEL = object() diff --git a/esphome/dashboard/core.py b/esphome/dashboard/core.py index f18da92d80..ffec9784e8 100644 --- a/esphome/dashboard/core.py +++ b/esphome/dashboard/core.py @@ -3,7 +3,9 @@ from __future__ import annotations import asyncio import logging import threading -from typing import TYPE_CHECKING +from dataclasses import dataclass +from functools import partial +from typing import TYPE_CHECKING, Any, Callable from ..zeroconf import DiscoveredImport from .entries import DashboardEntries @@ -12,16 +14,55 @@ from .settings import DashboardSettings if TYPE_CHECKING: from .status.mdns import MDNSStatus + _LOGGER = logging.getLogger(__name__) +@dataclass +class Event: + """Dashboard Event.""" + + event_type: str + data: dict[str, Any] + + +class EventBus: + """Dashboard event bus.""" + + def __init__(self) -> None: + """Initialize the Dashboard event bus.""" + self._listeners: dict[str, set[Callable[[Event], None]]] = {} + + def async_add_listener( + self, event_type: str, listener: Callable[[Event], None] + ) -> Callable[[], None]: + """Add a listener to the event bus.""" + self._listeners.setdefault(event_type, set()).add(listener) + return partial(self._async_remove_listener, event_type, listener) + + def _async_remove_listener( + self, event_type: str, listener: Callable[[Event], None] + ) -> None: + """Remove a listener from the event bus.""" + self._listeners[event_type].discard(listener) + + def async_fire(self, event_type: str, event_data: dict[str, Any]) -> None: + """Fire an event.""" + event = Event(event_type, event_data) + + _LOGGER.debug("Firing event: %s", event) + + for listener in self._listeners.get(event_type, set()): + listener(event) + + class ESPHomeDashboard: """Class that represents the dashboard.""" __slots__ = ( + "bus", "entries", "loop", - "ping_result", "import_result", "stop_event", "ping_request", @@ -32,9 +73,9 @@ class ESPHomeDashboard: def __init__(self) -> None: """Initialize the ESPHomeDashboard.""" + self.bus = EventBus() self.entries: DashboardEntries | None = None self.loop: asyncio.AbstractEventLoop | None = None - self.ping_result: dict[str, bool | None] = {} self.import_result: dict[str, DiscoveredImport] = {} self.stop_event = threading.Event() self.ping_request: asyncio.Event | None = None @@ -46,7 +87,7 @@ class ESPHomeDashboard: """Setup the dashboard.""" self.loop = asyncio.get_running_loop() self.ping_request = asyncio.Event() - self.entries = DashboardEntries(self.settings.config_dir) + self.entries = DashboardEntries(self) async def async_run(self) -> None: """Run the dashboard.""" diff --git a/esphome/dashboard/entries.py b/esphome/dashboard/entries.py index ff539fc620..42b3a2e743 100644 --- a/esphome/dashboard/entries.py +++ b/esphome/dashboard/entries.py @@ -3,24 +3,78 @@ from __future__ import annotations import asyncio import logging import os +from typing import TYPE_CHECKING, Any from esphome import const, util from esphome.storage_json import StorageJSON, ext_storage_path +from .const import ( + EVENT_ENTRY_ADDED, + EVENT_ENTRY_REMOVED, + EVENT_ENTRY_STATE_CHANGED, + EVENT_ENTRY_UPDATED, +) +from .enum import StrEnum + +if TYPE_CHECKING: + from .core import ESPHomeDashboard + _LOGGER = logging.getLogger(__name__) + DashboardCacheKeyType = tuple[int, int, float, int] +# Currently EntryState is a simple +# online/offline/unknown enum, but in the future +# it may be expanded to include more states + + +class EntryState(StrEnum): + ONLINE = "online" + OFFLINE = "offline" + UNKNOWN = "unknown" + + +_BOOL_TO_ENTRY_STATE = { + True: EntryState.ONLINE, + False: EntryState.OFFLINE, + None: EntryState.UNKNOWN, +} +_ENTRY_STATE_TO_BOOL = { + EntryState.ONLINE: True, + EntryState.OFFLINE: False, + EntryState.UNKNOWN: None, +} + + +def bool_to_entry_state(value: bool) -> EntryState: + """Convert a bool to an entry state.""" + return _BOOL_TO_ENTRY_STATE[value] + + +def entry_state_to_bool(value: EntryState) -> bool | None: + """Convert an entry state to a bool.""" + return _ENTRY_STATE_TO_BOOL[value] + class DashboardEntries: """Represents all dashboard entries.""" - __slots__ = ("_loop", "_config_dir", "_entries", "_loaded_entries", "_update_lock") + __slots__ = ( + "_dashboard", + "_loop", + "_config_dir", + "_entries", + "_entry_states", + "_loaded_entries", + "_update_lock", + ) - def __init__(self, config_dir: str) -> None: + def __init__(self, dashboard: ESPHomeDashboard) -> None: """Initialize the DashboardEntries.""" + self._dashboard = dashboard self._loop = asyncio.get_running_loop() - self._config_dir = config_dir + self._config_dir = dashboard.settings.config_dir # Entries are stored as # { # "path/to/file.yaml": DashboardEntry, @@ -46,6 +100,25 @@ class DashboardEntries: """Return all entries.""" return list(self._entries.values()) + def set_state(self, entry: DashboardEntry, state: EntryState) -> None: + """Set the state for an entry.""" + asyncio.run_coroutine_threadsafe( + self._async_set_state(entry, state), self._loop + ).result() + + async def _async_set_state(self, entry: DashboardEntry, state: EntryState) -> None: + """Set the state for an entry.""" + self.async_set_state(entry, state) + + def async_set_state(self, entry: DashboardEntry, state: EntryState) -> None: + """Set the state for an entry.""" + if entry.state == state: + return + entry.state = state + self._dashboard.bus.async_fire( + EVENT_ENTRY_STATE_CHANGED, {"entry": entry, "state": state} + ) + async def async_request_update_entries(self) -> None: """Request an update of the dashboard entries from disk. @@ -81,16 +154,17 @@ class DashboardEntries: path_to_cache_key = await self._loop.run_in_executor( None, self._get_path_to_cache_key ) + entries = self._entries added: dict[DashboardEntry, DashboardCacheKeyType] = {} updated: dict[DashboardEntry, DashboardCacheKeyType] = {} removed: set[DashboardEntry] = { entry - for filename, entry in self._entries.items() + for filename, entry in entries.items() if filename not in path_to_cache_key } - entries = self._entries + for path, cache_key in path_to_cache_key.items(): - if entry := self._entries.get(path): + if entry := entries.get(path): if entry.cache_key != cache_key: updated[entry] = cache_key else: @@ -102,17 +176,17 @@ class DashboardEntries: None, self._load_entries, {**added, **updated} ) + bus = self._dashboard.bus for entry in added: - _LOGGER.debug("Added dashboard entry %s", entry.path) entries[entry.path] = entry + bus.async_fire(EVENT_ENTRY_ADDED, {"entry": entry}) - if entry in removed: - _LOGGER.debug("Removed dashboard entry %s", entry.path) - entries.pop(entry.path) + for entry in removed: + del entries[entry.path] + bus.async_fire(EVENT_ENTRY_REMOVED, {"entry": entry}) for entry in updated: - _LOGGER.debug("Updated dashboard entry %s", entry.path) - # In the future we can fire events when entries are added/removed/updated + bus.async_fire(EVENT_ENTRY_UPDATED, {"entry": entry}) def _get_path_to_cache_key(self) -> dict[str, DashboardCacheKeyType]: """Return a dict of path to cache key.""" @@ -152,29 +226,64 @@ class DashboardEntry: This class is thread-safe and read-only. """ - __slots__ = ("path", "filename", "_storage_path", "cache_key", "storage") + __slots__ = ( + "path", + "filename", + "_storage_path", + "cache_key", + "storage", + "state", + "_to_dict", + ) def __init__(self, path: str, cache_key: DashboardCacheKeyType) -> None: """Initialize the DashboardEntry.""" self.path = path - self.filename = os.path.basename(path) + self.filename: str = os.path.basename(path) self._storage_path = ext_storage_path(self.filename) self.cache_key = cache_key self.storage: StorageJSON | None = None + self.state = EntryState.UNKNOWN + self._to_dict: dict[str, Any] | None = None def __repr__(self): """Return the representation of this entry.""" return ( - f"DashboardEntry({self.path} " + f"DashboardEntry(path={self.path} " f"address={self.address} " f"web_port={self.web_port} " f"name={self.name} " - f"no_mdns={self.no_mdns})" + f"no_mdns={self.no_mdns} " + f"state={self.state} " + ")" ) + def to_dict(self) -> dict[str, Any]: + """Return a dict representation of this entry. + + The dict includes the loaded configuration but not + the current state of the entry. + """ + if self._to_dict is None: + self._to_dict = { + "name": self.name, + "friendly_name": self.friendly_name, + "configuration": self.filename, + "loaded_integrations": self.loaded_integrations, + "deployed_version": self.update_old, + "current_version": self.update_new, + "path": self.path, + "comment": self.comment, + "address": self.address, + "web_port": self.web_port, + "target_platform": self.target_platform, + } + return self._to_dict + def load_from_disk(self, cache_key: DashboardCacheKeyType | None = None) -> None: """Load this entry from disk.""" self.storage = StorageJSON.load(self._storage_path) + self._to_dict = None # # Currently StorageJSON.load() will return None if the file does not exist # diff --git a/esphome/dashboard/enum.py b/esphome/dashboard/enum.py new file mode 100644 index 0000000000..6aff21620e --- /dev/null +++ b/esphome/dashboard/enum.py @@ -0,0 +1,19 @@ +"""Enum backports from standard lib.""" +from __future__ import annotations + +from enum import Enum +from typing import Any + + +class StrEnum(str, Enum): + """Partial backport of Python 3.11's StrEnum for our basic use cases.""" + + def __new__(cls, value: str, *args: Any, **kwargs: Any) -> StrEnum: + """Create a new StrEnum instance.""" + if not isinstance(value, str): + raise TypeError(f"{value!r} is not a string") + return super().__new__(cls, value, *args, **kwargs) + + def __str__(self) -> str: + """Return self.value.""" + return str(self.value) diff --git a/esphome/dashboard/status/mdns.py b/esphome/dashboard/status/mdns.py index 51d11390b7..cbe3b3309e 100644 --- a/esphome/dashboard/status/mdns.py +++ b/esphome/dashboard/status/mdns.py @@ -10,7 +10,9 @@ from esphome.zeroconf import ( DashboardStatus, ) +from ..const import SENTINEL from ..core import DASHBOARD +from ..entries import bool_to_entry_state class MDNSStatus: @@ -22,16 +24,16 @@ class MDNSStatus: self.aiozc: AsyncEsphomeZeroconf | None = None # This is the current mdns state for each host (True, False, None) self.host_mdns_state: dict[str, bool | None] = {} - # This is the hostnames to filenames mapping - self.host_name_to_filename: dict[str, str] = {} - self.filename_to_host_name: dict[str, str] = {} + # This is the hostnames to path mapping + self.host_name_to_path: dict[str, str] = {} + self.path_to_host_name: dict[str, str] = {} # This is a set of host names to track (i.e no_mdns = false) self.host_name_with_mdns_enabled: set[set] = set() self._loop = asyncio.get_running_loop() - def filename_to_host_name_thread_safe(self, filename: str) -> str | None: - """Resolve a filename to an address in a thread-safe manner.""" - return self.filename_to_host_name.get(filename) + def get_path_to_host_name(self, path: str) -> str | None: + """Resolve a path to an address in a thread-safe manner.""" + return self.path_to_host_name.get(path) async def async_resolve_host(self, host_name: str) -> str | None: """Resolve a host name to an address in a thread-safe manner.""" @@ -42,14 +44,14 @@ class MDNSStatus: async def async_refresh_hosts(self): """Refresh the hosts to track.""" dashboard = DASHBOARD - entries = dashboard.entries.async_all() + current_entries = dashboard.entries.async_all() host_name_with_mdns_enabled = self.host_name_with_mdns_enabled host_mdns_state = self.host_mdns_state - host_name_to_filename = self.host_name_to_filename - filename_to_host_name = self.filename_to_host_name - ping_result = dashboard.ping_result + host_name_to_path = self.host_name_to_path + path_to_host_name = self.path_to_host_name + entries = dashboard.entries - for entry in entries: + for entry in current_entries: name = entry.name # If no_mdns is set, remove it from the set if entry.no_mdns: @@ -58,37 +60,37 @@ class MDNSStatus: # We are tracking this host host_name_with_mdns_enabled.add(name) - filename = entry.filename + path = entry.path # If we just adopted/imported this host, we likely # already have a state for it, so we should make sure # to set it so the dashboard shows it as online - if name in host_mdns_state: - ping_result[filename] = host_mdns_state[name] + if (online := host_mdns_state.get(name, SENTINEL)) != SENTINEL: + entries.async_set_state(entry, bool_to_entry_state(online)) # Make sure the mapping is up to date # so when we get an mdns update we can map it back # to the filename - host_name_to_filename[name] = filename - filename_to_host_name[filename] = name + host_name_to_path[name] = path + path_to_host_name[path] = name async def async_run(self) -> None: dashboard = DASHBOARD - + entries = dashboard.entries aiozc = AsyncEsphomeZeroconf() self.aiozc = aiozc host_mdns_state = self.host_mdns_state - host_name_to_filename = self.host_name_to_filename + host_name_to_path = self.host_name_to_path host_name_with_mdns_enabled = self.host_name_with_mdns_enabled - ping_result = dashboard.ping_result def on_update(dat: dict[str, bool | None]) -> None: - """Update the global PING_RESULT dict.""" + """Update the entry state.""" for name, result in dat.items(): host_mdns_state[name] = result - if name in host_name_with_mdns_enabled: - filename = host_name_to_filename[name] - ping_result[filename] = result + if name not in host_name_with_mdns_enabled: + continue + if entry := entries.get(host_name_to_path[name]): + entries.async_set_state(entry, bool_to_entry_state(result)) stat = DashboardStatus(on_update) imports = DashboardImportDiscovery() diff --git a/esphome/dashboard/status/mqtt.py b/esphome/dashboard/status/mqtt.py index 2fd3a332a7..8c35dd2535 100644 --- a/esphome/dashboard/status/mqtt.py +++ b/esphome/dashboard/status/mqtt.py @@ -8,6 +8,7 @@ import threading from esphome import mqtt from ..core import DASHBOARD +from ..entries import EntryState class MqttStatusThread(threading.Thread): @@ -16,22 +17,23 @@ class MqttStatusThread(threading.Thread): def run(self) -> None: """Run the status thread.""" dashboard = DASHBOARD - entries = dashboard.entries.all() + entries = dashboard.entries + current_entries = entries.all() config = mqtt.config_from_env() topic = "esphome/discover/#" def on_message(client, userdata, msg): - nonlocal entries + nonlocal current_entries payload = msg.payload.decode(errors="backslashreplace") if len(payload) > 0: data = json.loads(payload) if "name" not in data: return - for entry in entries: + for entry in current_entries: if entry.name == data["name"]: - dashboard.ping_result[entry.filename] = True + entries.set_state(entry, EntryState.ONLINE) return def on_connect(client, userdata, flags, return_code): @@ -51,12 +53,11 @@ class MqttStatusThread(threading.Thread): client.loop_start() while not dashboard.stop_event.wait(2): - entries = dashboard.entries.all() - + current_entries = entries.all() # will be set to true on on_message - for entry in entries: + for entry in current_entries: if entry.no_mdns: - dashboard.ping_result[entry.filename] = False + entries.set_state(entry, EntryState.OFFLINE) client.publish("esphome/discover", None, retain=False) dashboard.mqtt_ping_request.wait() diff --git a/esphome/dashboard/status/ping.py b/esphome/dashboard/status/ping.py index 35fb2259f0..d8281d9de1 100644 --- a/esphome/dashboard/status/ping.py +++ b/esphome/dashboard/status/ping.py @@ -5,7 +5,7 @@ import os from typing import cast from ..core import DASHBOARD -from ..entries import DashboardEntry +from ..entries import DashboardEntry, bool_to_entry_state from ..util.itertools import chunked from ..util.subprocess import async_system_command_status @@ -26,14 +26,14 @@ class PingStatus: async def async_run(self) -> None: """Run the ping status.""" dashboard = DASHBOARD + entries = dashboard.entries while not dashboard.stop_event.is_set(): # Only ping if the dashboard is open await dashboard.ping_request.wait() - dashboard.ping_result.clear() - entries = dashboard.entries.async_all() + current_entries = dashboard.entries.async_all() to_ping: list[DashboardEntry] = [ - entry for entry in entries if entry.address is not None + entry for entry in current_entries if entry.address is not None ] for ping_group in chunked(to_ping, 16): ping_group = cast(list[DashboardEntry], ping_group) @@ -46,4 +46,4 @@ class PingStatus: result = False elif isinstance(result, BaseException): raise result - dashboard.ping_result[entry.filename] = result + entries.async_set_state(entry, bool_to_entry_state(result)) diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index 9a5de0a933..9972808948 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -37,6 +37,7 @@ from esphome.util import get_serial_ports, shlex_quote from esphome.yaml_util import FastestAvailableSafeLoader from .core import DASHBOARD +from .entries import EntryState, entry_state_to_bool from .util.subprocess import async_run_system_command from .util.text import friendly_name_slugify @@ -275,7 +276,7 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): if ( port == "OTA" and (mdns := dashboard.mdns_status) - and (host_name := mdns.filename_to_host_name_thread_safe(configuration)) + and (host_name := mdns.get_path_to_host_name(config_file)) and (address := await mdns.async_resolve_host(host_name)) ): port = address @@ -315,7 +316,9 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket): return # Remove the old ping result from the cache - DASHBOARD.ping_result.pop(self.old_name, None) + entries = DASHBOARD.entries + if entry := entries.get(self.old_name): + entries.async_set_state(entry, EntryState.UNKNOWN) class EsphomeUploadHandler(EsphomePortCommandWebSocket): @@ -609,22 +612,7 @@ class ListDevicesHandler(BaseHandler): self.write( json.dumps( { - "configured": [ - { - "name": entry.name, - "friendly_name": entry.friendly_name, - "configuration": entry.filename, - "loaded_integrations": entry.loaded_integrations, - "deployed_version": entry.update_old, - "current_version": entry.update_new, - "path": entry.path, - "comment": entry.comment, - "address": entry.address, - "web_port": entry.web_port, - "target_platform": entry.target_platform, - } - for entry in entries - ], + "configured": [entry.to_dict() for entry in entries], "importable": [ { "name": res.device_name, @@ -728,7 +716,15 @@ class PingRequestHandler(BaseHandler): if settings.status_use_mqtt: dashboard.mqtt_ping_request.set() self.set_header("content-type", "application/json") - self.write(json.dumps(dashboard.ping_result)) + + self.write( + json.dumps( + { + entry.filename: entry_state_to_bool(entry.state) + for entry in dashboard.entries.async_all() + } + ) + ) class InfoRequestHandler(BaseHandler): @@ -785,9 +781,6 @@ class DeleteRequestHandler(BaseHandler): if build_folder is not None: shutil.rmtree(build_folder, os.path.join(trash_path, name)) - # Remove the old ping result from the cache - DASHBOARD.ping_result.pop(configuration, None) - class UndoDeleteRequestHandler(BaseHandler): @authenticated