Refactor dashboard zeroconf support (#5681)

This commit is contained in:
J. Nick Koston 2023-11-06 16:07:59 -06:00 committed by GitHub
parent b978985aa1
commit fce59819f5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 123 additions and 148 deletions

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import base64 import base64
import binascii import binascii
import codecs import codecs
@ -15,7 +17,6 @@ import shutil
import subprocess import subprocess
import threading import threading
from pathlib import Path from pathlib import Path
from typing import Optional
import tornado import tornado
import tornado.concurrent import tornado.concurrent
@ -42,7 +43,13 @@ from esphome.storage_json import (
trash_storage_path, trash_storage_path,
) )
from esphome.util import get_serial_ports, shlex_quote from esphome.util import get_serial_ports, shlex_quote
from esphome.zeroconf import DashboardImportDiscovery, DashboardStatus, EsphomeZeroconf from esphome.zeroconf import (
ESPHOME_SERVICE_TYPE,
DashboardBrowser,
DashboardImportDiscovery,
DashboardStatus,
EsphomeZeroconf,
)
from .util import friendly_name_slugify, password_hash from .util import friendly_name_slugify, password_hash
@ -517,6 +524,8 @@ class ImportRequestHandler(BaseHandler):
network, network,
encryption, encryption,
) )
# Make sure the device gets marked online right away
PING_REQUEST.set()
except FileExistsError: except FileExistsError:
self.set_status(500) self.set_status(500)
self.write("File already exists") self.write("File already exists")
@ -542,13 +551,11 @@ class DownloadListRequestHandler(BaseHandler):
self.send_error(404) self.send_error(404)
return return
from esphome.components.esp32 import ( from esphome.components.esp32 import VARIANTS as ESP32_VARIANTS
get_download_types as esp32_types, from esphome.components.esp32 import get_download_types as esp32_types
VARIANTS as ESP32_VARIANTS,
)
from esphome.components.esp8266 import get_download_types as esp8266_types from esphome.components.esp8266 import get_download_types as esp8266_types
from esphome.components.rp2040 import get_download_types as rp2040_types
from esphome.components.libretiny import get_download_types as libretiny_types from esphome.components.libretiny import get_download_types as libretiny_types
from esphome.components.rp2040 import get_download_types as rp2040_types
downloads = [] downloads = []
platform = storage_json.target_platform.lower() platform = storage_json.target_platform.lower()
@ -661,12 +668,21 @@ class DashboardEntry:
self._storage = None self._storage = None
self._loaded_storage = False self._loaded_storage = False
def __repr__(self):
return (
f"DashboardEntry({self.path} "
f"address={self.address} "
f"web_port={self.web_port} "
f"name={self.name} "
f"no_mdns={self.no_mdns})"
)
@property @property
def filename(self): def filename(self):
return os.path.basename(self.path) return os.path.basename(self.path)
@property @property
def storage(self) -> Optional[StorageJSON]: def storage(self) -> StorageJSON | None:
if not self._loaded_storage: if not self._loaded_storage:
self._storage = StorageJSON.load(ext_storage_path(self.filename)) self._storage = StorageJSON.load(ext_storage_path(self.filename))
self._loaded_storage = True self._loaded_storage = True
@ -831,10 +847,10 @@ class PrometheusServiceDiscoveryHandler(BaseHandler):
class BoardsRequestHandler(BaseHandler): class BoardsRequestHandler(BaseHandler):
@authenticated @authenticated
def get(self, platform: str): def get(self, platform: str):
from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS
from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS
from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS
from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS
from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS
from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS
platform_to_boards = { platform_to_boards = {
@ -865,35 +881,76 @@ class BoardsRequestHandler(BaseHandler):
class MDNSStatusThread(threading.Thread): class MDNSStatusThread(threading.Thread):
def __init__(self):
"""Initialize the MDNSStatusThread."""
super().__init__()
# This is the current mdns state for each host (True, False, None)
self.host_mdns_state: dict[str, bool | None] = {}
# This is the hostnames to filenames mapping
self.host_name_to_filename: dict[str, str] = {}
# This is a set of host names to track (i.e no_mdns = false)
self.host_name_with_mdns_enabled: set[set] = set()
self._refresh_hosts()
def _refresh_hosts(self):
"""Refresh the hosts to track."""
entries = _list_dashboard_entries()
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
host_mdns_state = self.host_mdns_state
host_name_to_filename = self.host_name_to_filename
for entry in entries:
name = entry.name
# If no_mdns is set, remove it from the set
if entry.no_mdns:
host_name_with_mdns_enabled.discard(name)
continue
# We are tracking this host
host_name_with_mdns_enabled.add(name)
filename = entry.filename
# If we just adopted/imported this host, we likely
# already have a state for it, so we should make sure
# to set it so the dashboard shows it as online
if name in host_mdns_state:
PING_RESULT[filename] = host_mdns_state[name]
# Make sure the mapping is up to date
# so when we get an mdns update we can map it back
# to the filename
host_name_to_filename[name] = filename
def run(self): def run(self):
global IMPORT_RESULT global IMPORT_RESULT
zc = EsphomeZeroconf() zc = EsphomeZeroconf()
host_mdns_state = self.host_mdns_state
host_name_to_filename = self.host_name_to_filename
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
def on_update(dat): def on_update(dat: dict[str, bool | None]) -> None:
for key, b in dat.items(): """Update the global PING_RESULT dict."""
PING_RESULT[key] = b for name, result in dat.items():
host_mdns_state[name] = result
if name in host_name_with_mdns_enabled:
filename = host_name_to_filename[name]
PING_RESULT[filename] = result
stat = DashboardStatus(zc, on_update) self._refresh_hosts()
imports = DashboardImportDiscovery(zc) stat = DashboardStatus(on_update)
imports = DashboardImportDiscovery()
browser = DashboardBrowser(
zc, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback]
)
stat.start()
while not STOP_EVENT.is_set(): while not STOP_EVENT.is_set():
entries = _list_dashboard_entries() self._refresh_hosts()
hosts = {}
for entry in entries:
if entry.no_mdns is not True:
hosts[entry.filename] = f"{entry.name}.local."
stat.request_query(hosts)
IMPORT_RESULT = imports.import_state IMPORT_RESULT = imports.import_state
PING_REQUEST.wait() PING_REQUEST.wait()
PING_REQUEST.clear() PING_REQUEST.clear()
stat.stop() browser.cancel()
stat.join()
imports.cancel()
zc.close() zc.close()

View file

@ -1,130 +1,49 @@
from __future__ import annotations
import logging import logging
import socket
import threading
import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Callable
from zeroconf import ( from zeroconf import (
DNSAddress, IPVersion,
DNSOutgoing,
DNSQuestion,
RecordUpdate,
RecordUpdateListener,
ServiceBrowser, ServiceBrowser,
ServiceInfo,
ServiceStateChange, ServiceStateChange,
Zeroconf, Zeroconf,
current_time_millis,
) )
from esphome.storage_json import StorageJSON, ext_storage_path from esphome.storage_json import StorageJSON, ext_storage_path
_CLASS_IN = 1
_FLAGS_QR_QUERY = 0x0000 # query
_TYPE_A = 1
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class HostResolver(RecordUpdateListener): class HostResolver(ServiceInfo):
"""Resolve a host name to an IP address.""" """Resolve a host name to an IP address."""
def __init__(self, name: str): @property
self.name = name def _is_complete(self) -> bool:
self.address: Optional[bytes] = None """The ServiceInfo has all expected properties."""
return bool(self._ipv4_addresses)
def async_update_records(
self, zc: Zeroconf, now: float, records: list[RecordUpdate]
) -> None:
"""Update multiple records in one shot.
This will run in zeroconf's event loop thread so it
must be thread-safe.
"""
for record_update in records:
record, _ = record_update
if record is None:
continue
if record.type == _TYPE_A:
assert isinstance(record, DNSAddress)
if record.name == self.name:
self.address = record.address
def request(self, zc: Zeroconf, timeout: float) -> bool:
now = time.time()
delay = 0.2
next_ = now + delay
last = now + timeout
try:
zc.add_listener(self, None)
while self.address is None:
if last <= now:
# Timeout
return False
if next_ <= now:
out = DNSOutgoing(_FLAGS_QR_QUERY)
out.add_question(DNSQuestion(self.name, _TYPE_A, _CLASS_IN))
zc.send(out)
next_ = now + delay
delay *= 2
time.sleep(min(next_, last) - now)
now = time.time()
finally:
zc.remove_listener(self)
return True
class DashboardStatus(threading.Thread): class DashboardStatus:
PING_AFTER = 15 * 1000 # Send new mDNS request after 15 seconds def __init__(self, on_update: Callable[[dict[str, bool | None], []]]) -> None:
OFFLINE_AFTER = PING_AFTER * 2 # Offline if no mDNS response after 30 seconds """Initialize the dashboard status."""
def __init__(self, zc: Zeroconf, on_update) -> None:
threading.Thread.__init__(self)
self.zc = zc
self.query_hosts: set[str] = set()
self.key_to_host: dict[str, str] = {}
self.stop_event = threading.Event()
self.query_event = threading.Event()
self.on_update = on_update self.on_update = on_update
def request_query(self, hosts: dict[str, str]) -> None: def browser_callback(
self.query_hosts = set(hosts.values()) self,
self.key_to_host = hosts zeroconf: Zeroconf,
self.query_event.set() service_type: str,
name: str,
def stop(self) -> None: state_change: ServiceStateChange,
self.stop_event.set() ) -> None:
self.query_event.set() """Handle a service update."""
short_name = name.partition(".")[0]
def host_status(self, key: str) -> bool: if state_change == ServiceStateChange.Removed:
entries = self.zc.cache.entries_with_name(key) self.on_update({short_name: False})
if not entries: elif state_change in (ServiceStateChange.Updated, ServiceStateChange.Added):
return False self.on_update({short_name: True})
now = current_time_millis()
return any(
(entry.created + DashboardStatus.OFFLINE_AFTER) >= now for entry in entries
)
def run(self) -> None:
while not self.stop_event.is_set():
self.on_update(
{key: self.host_status(host) for key, host in self.key_to_host.items()}
)
now = current_time_millis()
for host in self.query_hosts:
entries = self.zc.cache.entries_with_name(host)
if not entries or all(
(entry.created + DashboardStatus.PING_AFTER) <= now
for entry in entries
):
out = DNSOutgoing(_FLAGS_QR_QUERY)
out.add_question(DNSQuestion(host, _TYPE_A, _CLASS_IN))
self.zc.send(out)
self.query_event.wait()
self.query_event.clear()
ESPHOME_SERVICE_TYPE = "_esphomelib._tcp.local." ESPHOME_SERVICE_TYPE = "_esphomelib._tcp.local."
@ -138,7 +57,7 @@ TXT_RECORD_VERSION = b"version"
@dataclass @dataclass
class DiscoveredImport: class DiscoveredImport:
friendly_name: Optional[str] friendly_name: str | None
device_name: str device_name: str
package_import_url: str package_import_url: str
project_name: str project_name: str
@ -146,15 +65,15 @@ class DiscoveredImport:
network: str network: str
class DashboardBrowser(ServiceBrowser):
"""A class to browse for ESPHome nodes."""
class DashboardImportDiscovery: class DashboardImportDiscovery:
def __init__(self, zc: Zeroconf) -> None: def __init__(self) -> None:
self.zc = zc
self.service_browser = ServiceBrowser(
self.zc, ESPHOME_SERVICE_TYPE, [self._on_update]
)
self.import_state: dict[str, DiscoveredImport] = {} self.import_state: dict[str, DiscoveredImport] = {}
def _on_update( def browser_callback(
self, self,
zeroconf: Zeroconf, zeroconf: Zeroconf,
service_type: str, service_type: str,
@ -167,8 +86,6 @@ class DashboardImportDiscovery:
name, name,
state_change, state_change,
) )
if service_type != ESPHOME_SERVICE_TYPE:
return
if state_change == ServiceStateChange.Removed: if state_change == ServiceStateChange.Removed:
self.import_state.pop(name, None) self.import_state.pop(name, None)
return return
@ -212,9 +129,6 @@ class DashboardImportDiscovery:
network=network, network=network,
) )
def cancel(self) -> None:
self.service_browser.cancel()
def update_device_mdns(self, node_name: str, version: str): def update_device_mdns(self, node_name: str, version: str):
storage_path = ext_storage_path(node_name + ".yaml") storage_path = ext_storage_path(node_name + ".yaml")
storage_json = StorageJSON.load(storage_path) storage_json = StorageJSON.load(storage_path)
@ -234,7 +148,11 @@ class DashboardImportDiscovery:
class EsphomeZeroconf(Zeroconf): class EsphomeZeroconf(Zeroconf):
def resolve_host(self, host: str, timeout=3.0): def resolve_host(self, host: str, timeout=3.0):
info = HostResolver(host) """Resolve a host name to an IP address."""
if info.request(self, timeout): name = host.partition(".")[0]
return socket.inet_ntoa(info.address) info = HostResolver(f"{name}.{ESPHOME_SERVICE_TYPE}", ESPHOME_SERVICE_TYPE)
if (info.load_from_cache(self) or info.request(self, timeout * 1000)) and (
addresses := info.ip_addresses_by_version(IPVersion.V4Only)
):
return str(addresses[0])
return None return None