dashboard: Add lookup by name to entries (#5790)

* Add lookup by name to entries

* adj

* tweak

* tweak

* tweak

* tweak

* tweak

* tweak

* preen
This commit is contained in:
J. Nick Koston 2023-11-19 21:29:40 -06:00 committed by GitHub
parent 4e4fe3c26d
commit cd9bf29df1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 44 deletions

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
from collections import defaultdict
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from esphome import const, util from esphome import const, util
@ -68,6 +69,7 @@ class DashboardEntries:
"_entry_states", "_entry_states",
"_loaded_entries", "_loaded_entries",
"_update_lock", "_update_lock",
"_name_to_entry",
) )
def __init__(self, dashboard: ESPHomeDashboard) -> None: def __init__(self, dashboard: ESPHomeDashboard) -> None:
@ -83,11 +85,16 @@ class DashboardEntries:
self._entries: dict[str, DashboardEntry] = {} self._entries: dict[str, DashboardEntry] = {}
self._loaded_entries = False self._loaded_entries = False
self._update_lock = asyncio.Lock() self._update_lock = asyncio.Lock()
self._name_to_entry: dict[str, set[DashboardEntry]] = defaultdict(set)
def get(self, path: str) -> DashboardEntry | None: def get(self, path: str) -> DashboardEntry | None:
"""Get an entry by path.""" """Get an entry by path."""
return self._entries.get(path) return self._entries.get(path)
def get_by_name(self, name: str) -> set[DashboardEntry] | None:
"""Get an entry by name."""
return self._name_to_entry.get(name)
async def _async_all(self) -> list[DashboardEntry]: async def _async_all(self) -> list[DashboardEntry]:
"""Return all entries.""" """Return all entries."""
return list(self._entries.values()) return list(self._entries.values())
@ -155,6 +162,7 @@ class DashboardEntries:
None, self._get_path_to_cache_key None, self._get_path_to_cache_key
) )
entries = self._entries entries = self._entries
name_to_entry = self._name_to_entry
added: dict[DashboardEntry, DashboardCacheKeyType] = {} added: dict[DashboardEntry, DashboardCacheKeyType] = {}
updated: dict[DashboardEntry, DashboardCacheKeyType] = {} updated: dict[DashboardEntry, DashboardCacheKeyType] = {}
removed: set[DashboardEntry] = { removed: set[DashboardEntry] = {
@ -162,14 +170,17 @@ class DashboardEntries:
for filename, entry in entries.items() for filename, entry in entries.items()
if filename not in path_to_cache_key if filename not in path_to_cache_key
} }
original_names: dict[DashboardEntry, str] = {}
for path, cache_key in path_to_cache_key.items(): for path, cache_key in path_to_cache_key.items():
if entry := entries.get(path): if not (entry := entries.get(path)):
if entry.cache_key != cache_key:
updated[entry] = cache_key
else:
entry = DashboardEntry(path, cache_key) entry = DashboardEntry(path, cache_key)
added[entry] = cache_key added[entry] = cache_key
continue
if entry.cache_key != cache_key:
updated[entry] = cache_key
original_names[entry] = entry.name
if added or updated: if added or updated:
await self._loop.run_in_executor( await self._loop.run_in_executor(
@ -179,13 +190,18 @@ class DashboardEntries:
bus = self._dashboard.bus bus = self._dashboard.bus
for entry in added: for entry in added:
entries[entry.path] = entry entries[entry.path] = entry
name_to_entry[entry.name].add(entry)
bus.async_fire(EVENT_ENTRY_ADDED, {"entry": entry}) bus.async_fire(EVENT_ENTRY_ADDED, {"entry": entry})
for entry in removed: for entry in removed:
del entries[entry.path] del entries[entry.path]
name_to_entry[entry.name].discard(entry)
bus.async_fire(EVENT_ENTRY_REMOVED, {"entry": entry}) bus.async_fire(EVENT_ENTRY_REMOVED, {"entry": entry})
for entry in updated: for entry in updated:
if (original_name := original_names[entry]) != (current_name := entry.name):
name_to_entry[original_name].discard(entry)
name_to_entry[current_name].add(entry)
bus.async_fire(EVENT_ENTRY_UPDATED, {"entry": entry}) bus.async_fire(EVENT_ENTRY_UPDATED, {"entry": entry})
def _get_path_to_cache_key(self) -> dict[str, DashboardCacheKeyType]: def _get_path_to_cache_key(self) -> dict[str, DashboardCacheKeyType]:

View file

@ -24,17 +24,8 @@ class MDNSStatus:
self.aiozc: AsyncEsphomeZeroconf | None = None self.aiozc: AsyncEsphomeZeroconf | None = None
# This is the current mdns state for each host (True, False, None) # This is the current mdns state for each host (True, False, None)
self.host_mdns_state: dict[str, bool | None] = {} self.host_mdns_state: dict[str, bool | None] = {}
# This is the hostnames to path mapping
self.host_name_to_path: dict[str, str] = {}
self.path_to_host_name: dict[str, str] = {}
# This is a set of host names to track (i.e no_mdns = false)
self.host_name_with_mdns_enabled: set[set] = set()
self._loop = asyncio.get_running_loop() self._loop = asyncio.get_running_loop()
def get_path_to_host_name(self, path: str) -> str | None:
"""Resolve a path to an address in a thread-safe manner."""
return self.path_to_host_name.get(path)
async def async_resolve_host(self, host_name: str) -> str | None: async def async_resolve_host(self, host_name: str) -> str | None:
"""Resolve a host name to an address in a thread-safe manner.""" """Resolve a host name to an address in a thread-safe manner."""
if aiozc := self.aiozc: if aiozc := self.aiozc:
@ -44,52 +35,31 @@ class MDNSStatus:
async def async_refresh_hosts(self): async def async_refresh_hosts(self):
"""Refresh the hosts to track.""" """Refresh the hosts to track."""
dashboard = DASHBOARD dashboard = DASHBOARD
current_entries = dashboard.entries.async_all()
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
host_mdns_state = self.host_mdns_state host_mdns_state = self.host_mdns_state
host_name_to_path = self.host_name_to_path
path_to_host_name = self.path_to_host_name
entries = dashboard.entries entries = dashboard.entries
for entry in entries.async_all():
for entry in current_entries:
name = entry.name
# If no_mdns is set, remove it from the set
if entry.no_mdns: if entry.no_mdns:
host_name_with_mdns_enabled.discard(name)
continue continue
# We are tracking this host
host_name_with_mdns_enabled.add(name)
path = entry.path
# If we just adopted/imported this host, we likely # If we just adopted/imported this host, we likely
# already have a state for it, so we should make sure # already have a state for it, so we should make sure
# to set it so the dashboard shows it as online # to set it so the dashboard shows it as online
if (online := host_mdns_state.get(name, SENTINEL)) != SENTINEL: if (online := host_mdns_state.get(entry.name, SENTINEL)) != SENTINEL:
entries.async_set_state(entry, bool_to_entry_state(online)) entries.async_set_state(entry, bool_to_entry_state(online))
# Make sure the mapping is up to date
# so when we get an mdns update we can map it back
# to the filename
host_name_to_path[name] = path
path_to_host_name[path] = name
async def async_run(self) -> None: async def async_run(self) -> None:
dashboard = DASHBOARD dashboard = DASHBOARD
entries = dashboard.entries entries = dashboard.entries
aiozc = AsyncEsphomeZeroconf() aiozc = AsyncEsphomeZeroconf()
self.aiozc = aiozc self.aiozc = aiozc
host_mdns_state = self.host_mdns_state host_mdns_state = self.host_mdns_state
host_name_to_path = self.host_name_to_path
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
def on_update(dat: dict[str, bool | None]) -> None: def on_update(dat: dict[str, bool | None]) -> None:
"""Update the entry state.""" """Update the entry state."""
for name, result in dat.items(): for name, result in dat.items():
host_mdns_state[name] = result host_mdns_state[name] = result
if name not in host_name_with_mdns_enabled: if matching_entries := entries.get_by_name(name):
continue for entry in matching_entries:
if entry := entries.get(host_name_to_path[name]): if not entry.no_mdns:
entries.async_set_state(entry, bool_to_entry_state(result)) entries.async_set_state(entry, bool_to_entry_state(result))
stat = DashboardStatus(on_update) stat = DashboardStatus(on_update)
@ -102,10 +72,11 @@ class MDNSStatus:
[stat.browser_callback, imports.browser_callback], [stat.browser_callback, imports.browser_callback],
) )
ping_request = dashboard.ping_request
while not dashboard.stop_event.is_set(): while not dashboard.stop_event.is_set():
await self.async_refresh_hosts() await self.async_refresh_hosts()
await dashboard.ping_request.wait() await ping_request.wait()
dashboard.ping_request.clear() ping_request.clear()
await browser.async_cancel() await browser.async_cancel()
await aiozc.async_close() await aiozc.async_close()

View file

@ -271,14 +271,15 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
) -> list[str]: ) -> list[str]:
"""Build the command to run.""" """Build the command to run."""
dashboard = DASHBOARD dashboard = DASHBOARD
entries = dashboard.entries
configuration = json_message["configuration"] configuration = json_message["configuration"]
config_file = settings.rel_path(configuration) config_file = settings.rel_path(configuration)
port = json_message["port"] port = json_message["port"]
if ( if (
port == "OTA" port == "OTA"
and (mdns := dashboard.mdns_status) and (mdns := dashboard.mdns_status)
and (host_name := mdns.get_path_to_host_name(config_file)) and (entry := entries.get(config_file))
and (address := await mdns.async_resolve_host(host_name)) and (address := await mdns.async_resolve_host(entry.name))
): ):
port = address port = address