From cd9bf29df112506387cb32f4fadaa19df21484a0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 19 Nov 2023 21:29:40 -0600 Subject: [PATCH] dashboard: Add lookup by name to entries (#5790) * Add lookup by name to entries * adj * tweak * tweak * tweak * tweak * tweak * tweak * preen --- esphome/dashboard/entries.py | 24 +++++++++++++--- esphome/dashboard/status/mdns.py | 47 ++++++-------------------------- esphome/dashboard/web_server.py | 5 ++-- 3 files changed, 32 insertions(+), 44 deletions(-) diff --git a/esphome/dashboard/entries.py b/esphome/dashboard/entries.py index 42b3a2e743..c5d7f3a245 100644 --- a/esphome/dashboard/entries.py +++ b/esphome/dashboard/entries.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import logging import os +from collections import defaultdict from typing import TYPE_CHECKING, Any from esphome import const, util @@ -68,6 +69,7 @@ class DashboardEntries: "_entry_states", "_loaded_entries", "_update_lock", + "_name_to_entry", ) def __init__(self, dashboard: ESPHomeDashboard) -> None: @@ -83,11 +85,16 @@ class DashboardEntries: self._entries: dict[str, DashboardEntry] = {} self._loaded_entries = False self._update_lock = asyncio.Lock() + self._name_to_entry: dict[str, set[DashboardEntry]] = defaultdict(set) def get(self, path: str) -> DashboardEntry | None: """Get an entry by path.""" return self._entries.get(path) + def get_by_name(self, name: str) -> set[DashboardEntry] | None: + """Get an entry by name.""" + return self._name_to_entry.get(name) + async def _async_all(self) -> list[DashboardEntry]: """Return all entries.""" return list(self._entries.values()) @@ -155,6 +162,7 @@ class DashboardEntries: None, self._get_path_to_cache_key ) entries = self._entries + name_to_entry = self._name_to_entry added: dict[DashboardEntry, DashboardCacheKeyType] = {} updated: dict[DashboardEntry, DashboardCacheKeyType] = {} removed: set[DashboardEntry] = { @@ -162,14 +170,17 @@ class DashboardEntries: for filename, entry in entries.items() if filename not in path_to_cache_key } + original_names: dict[DashboardEntry, str] = {} for path, cache_key in path_to_cache_key.items(): - if entry := entries.get(path): - if entry.cache_key != cache_key: - updated[entry] = cache_key - else: + if not (entry := entries.get(path)): entry = DashboardEntry(path, cache_key) added[entry] = cache_key + continue + + if entry.cache_key != cache_key: + updated[entry] = cache_key + original_names[entry] = entry.name if added or updated: await self._loop.run_in_executor( @@ -179,13 +190,18 @@ class DashboardEntries: bus = self._dashboard.bus for entry in added: entries[entry.path] = entry + name_to_entry[entry.name].add(entry) bus.async_fire(EVENT_ENTRY_ADDED, {"entry": entry}) for entry in removed: del entries[entry.path] + name_to_entry[entry.name].discard(entry) bus.async_fire(EVENT_ENTRY_REMOVED, {"entry": entry}) for entry in updated: + if (original_name := original_names[entry]) != (current_name := entry.name): + name_to_entry[original_name].discard(entry) + name_to_entry[current_name].add(entry) bus.async_fire(EVENT_ENTRY_UPDATED, {"entry": entry}) def _get_path_to_cache_key(self) -> dict[str, DashboardCacheKeyType]: diff --git a/esphome/dashboard/status/mdns.py b/esphome/dashboard/status/mdns.py index cbe3b3309e..4f4fa560d0 100644 --- a/esphome/dashboard/status/mdns.py +++ b/esphome/dashboard/status/mdns.py @@ -24,17 +24,8 @@ class MDNSStatus: self.aiozc: AsyncEsphomeZeroconf | None = None # This is the current mdns state for each host (True, False, None) self.host_mdns_state: dict[str, bool | None] = {} - # This is the hostnames to path mapping - self.host_name_to_path: dict[str, str] = {} - self.path_to_host_name: dict[str, str] = {} - # This is a set of host names to track (i.e no_mdns = false) - self.host_name_with_mdns_enabled: set[set] = set() self._loop = asyncio.get_running_loop() - def get_path_to_host_name(self, path: str) -> str | None: - """Resolve a path to an address in a thread-safe manner.""" - return self.path_to_host_name.get(path) - async def async_resolve_host(self, host_name: str) -> str | None: """Resolve a host name to an address in a thread-safe manner.""" if aiozc := self.aiozc: @@ -44,53 +35,32 @@ class MDNSStatus: async def async_refresh_hosts(self): """Refresh the hosts to track.""" dashboard = DASHBOARD - current_entries = dashboard.entries.async_all() - host_name_with_mdns_enabled = self.host_name_with_mdns_enabled host_mdns_state = self.host_mdns_state - host_name_to_path = self.host_name_to_path - path_to_host_name = self.path_to_host_name entries = dashboard.entries - - for entry in current_entries: - name = entry.name - # If no_mdns is set, remove it from the set + for entry in entries.async_all(): if entry.no_mdns: - host_name_with_mdns_enabled.discard(name) continue - - # We are tracking this host - host_name_with_mdns_enabled.add(name) - path = entry.path - # If we just adopted/imported this host, we likely # already have a state for it, so we should make sure # to set it so the dashboard shows it as online - if (online := host_mdns_state.get(name, SENTINEL)) != SENTINEL: + if (online := host_mdns_state.get(entry.name, SENTINEL)) != SENTINEL: entries.async_set_state(entry, bool_to_entry_state(online)) - # Make sure the mapping is up to date - # so when we get an mdns update we can map it back - # to the filename - host_name_to_path[name] = path - path_to_host_name[path] = name - async def async_run(self) -> None: dashboard = DASHBOARD entries = dashboard.entries aiozc = AsyncEsphomeZeroconf() self.aiozc = aiozc host_mdns_state = self.host_mdns_state - host_name_to_path = self.host_name_to_path - host_name_with_mdns_enabled = self.host_name_with_mdns_enabled def on_update(dat: dict[str, bool | None]) -> None: """Update the entry state.""" for name, result in dat.items(): host_mdns_state[name] = result - if name not in host_name_with_mdns_enabled: - continue - if entry := entries.get(host_name_to_path[name]): - entries.async_set_state(entry, bool_to_entry_state(result)) + if matching_entries := entries.get_by_name(name): + for entry in matching_entries: + if not entry.no_mdns: + entries.async_set_state(entry, bool_to_entry_state(result)) stat = DashboardStatus(on_update) imports = DashboardImportDiscovery() @@ -102,10 +72,11 @@ class MDNSStatus: [stat.browser_callback, imports.browser_callback], ) + ping_request = dashboard.ping_request while not dashboard.stop_event.is_set(): await self.async_refresh_hosts() - await dashboard.ping_request.wait() - dashboard.ping_request.clear() + await ping_request.wait() + ping_request.clear() await browser.async_cancel() await aiozc.async_close() diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index 8901da095f..7c5f653b5b 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -271,14 +271,15 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): ) -> list[str]: """Build the command to run.""" dashboard = DASHBOARD + entries = dashboard.entries configuration = json_message["configuration"] config_file = settings.rel_path(configuration) port = json_message["port"] if ( port == "OTA" and (mdns := dashboard.mdns_status) - and (host_name := mdns.get_path_to_host_name(config_file)) - and (address := await mdns.async_resolve_host(host_name)) + and (entry := entries.get(config_file)) + and (address := await mdns.async_resolve_host(entry.name)) ): port = address