dashboard: refactor ping implementation to be more efficient (#6002)

This commit is contained in:
J. Nick Koston 2024-01-08 15:35:43 -10:00 committed by GitHub
parent 87301a2e76
commit 6dfdcff66c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 255 additions and 21 deletions

View file

@ -4,5 +4,7 @@ EVENT_ENTRY_ADDED = "entry_added"
EVENT_ENTRY_REMOVED = "entry_removed" EVENT_ENTRY_REMOVED = "entry_removed"
EVENT_ENTRY_UPDATED = "entry_updated" EVENT_ENTRY_UPDATED = "entry_updated"
EVENT_ENTRY_STATE_CHANGED = "entry_state_changed" EVENT_ENTRY_STATE_CHANGED = "entry_state_changed"
MAX_EXECUTOR_WORKERS = 48
SENTINEL = object() SENTINEL = object()

View file

@ -8,6 +8,7 @@ from functools import partial
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
from ..zeroconf import DiscoveredImport from ..zeroconf import DiscoveredImport
from .dns import DNSCache
from .entries import DashboardEntries from .entries import DashboardEntries
from .settings import DashboardSettings from .settings import DashboardSettings
@ -69,6 +70,7 @@ class ESPHomeDashboard:
"mqtt_ping_request", "mqtt_ping_request",
"mdns_status", "mdns_status",
"settings", "settings",
"dns_cache",
) )
def __init__(self) -> None: def __init__(self) -> None:
@ -81,7 +83,8 @@ class ESPHomeDashboard:
self.ping_request: asyncio.Event | None = None self.ping_request: asyncio.Event | None = None
self.mqtt_ping_request = threading.Event() self.mqtt_ping_request = threading.Event()
self.mdns_status: MDNSStatus | None = None self.mdns_status: MDNSStatus | None = None
self.settings: DashboardSettings = DashboardSettings() self.settings = DashboardSettings()
self.dns_cache = DNSCache()
async def async_setup(self) -> None: async def async_setup(self) -> None:
"""Setup the dashboard.""" """Setup the dashboard."""

View file

@ -1,11 +1,19 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import logging
import os import os
import socket import socket
import threading
import traceback
from asyncio import events
from concurrent.futures import ThreadPoolExecutor
from time import monotonic
from typing import Any
from esphome.storage_json import EsphomeStorageJSON, esphome_storage_path from esphome.storage_json import EsphomeStorageJSON, esphome_storage_path
from .const import MAX_EXECUTOR_WORKERS
from .core import DASHBOARD from .core import DASHBOARD
from .web_server import make_app, start_web_server from .web_server import make_app, start_web_server
@ -14,6 +22,95 @@ ENV_DEV = "ESPHOME_DASHBOARD_DEV"
settings = DASHBOARD.settings settings = DASHBOARD.settings
def can_use_pidfd() -> bool:
    """Report whether ``os.pidfd_open`` exists and actually works here.

    Back ported from cpython 3.12

    The probe opens a pidfd for the current process and closes it again
    straight away; any ``OSError`` along the way counts as "not usable".
    """
    pidfd_open = getattr(os, "pidfd_open", None)
    if pidfd_open is None:
        # This platform / Python build does not expose pidfd_open at all.
        return False

    try:
        os.close(pidfd_open(os.getpid(), 0))
    except OSError:
        # blocked by security policy like SECCOMP
        return False
    else:
        return True
class DashboardEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
    """Event loop policy for the ESPHome dashboard.

    Every loop created through this policy gets the module's loop exception
    handler, a default thread pool executor capped at MAX_EXECUTOR_WORKERS
    and, when ``debug`` is set, asyncio debug mode.
    """

    def __init__(self, debug: bool) -> None:
        """Init the event loop policy.

        ``debug`` decides whether new loops are switched to asyncio debug mode.
        """
        super().__init__()
        self.debug = debug
        # Child watcher, created lazily by _init_watcher().
        self._watcher: asyncio.AbstractChildWatcher | None = None

    def _init_watcher(self) -> None:
        """Initialize the watcher for child processes.

        Back ported from cpython 3.12

        Uses a pidfd based watcher when can_use_pidfd() reports support and a
        thread based watcher otherwise.
        """
        # NOTE(review): asyncio child watchers were deprecated in Python 3.12
        # and removed in 3.14 - confirm against the supported Python versions.
        with events._lock:  # type: ignore[attr-defined] # pylint: disable=protected-access
            if self._watcher is None:  # pragma: no branch
                if can_use_pidfd():
                    self._watcher = asyncio.PidfdChildWatcher()
                else:
                    self._watcher = asyncio.ThreadedChildWatcher()
                # The loop is only attached when running on the main thread.
                if threading.current_thread() is threading.main_thread():
                    self._watcher.attach_loop(
                        self._local._loop  # type: ignore[attr-defined] # pylint: disable=protected-access
                    )

    @property
    def loop_name(self) -> str:
        """Return name of the loop."""
        # NOTE(review): _loop_factory is expected to come from the asyncio base
        # policy class; it is not defined in this file - confirm.
        return self._loop_factory.__name__  # type: ignore[no-any-return,attr-defined]

    def new_event_loop(self) -> asyncio.AbstractEventLoop:
        """Create a new event loop and configure it for the dashboard."""
        loop: asyncio.AbstractEventLoop = super().new_event_loop()
        loop.set_exception_handler(_async_loop_exception_handler)
        if self.debug:
            loop.set_debug(True)

        # Replace the default executor with one that has an explicit worker
        # limit and a recognizable thread name prefix.
        executor = ThreadPoolExecutor(
            thread_name_prefix="SyncWorker", max_workers=MAX_EXECUTOR_WORKERS
        )
        loop.set_default_executor(executor)
        # bind the built-in time.monotonic directly as loop.time to avoid the
        # overhead of the additional method call since it's the most called loop
        # method and it's roughly 10%+ of all the call time in base_events.py
        loop.time = monotonic  # type: ignore[method-assign]
        return loop
def _async_loop_exception_handler(_: Any, context: dict[str, Any]) -> None:
    """Log an error reported through the event loop's exception handler.

    The first argument (the loop) is ignored. ``context`` must contain a
    ``"message"`` entry; an optional ``"exception"`` is forwarded to the
    logger as ``exc_info`` and an optional ``"source_traceback"`` is
    formatted and appended to the logged message.
    """
    error = context.get("exception")
    log_kwargs = (
        {"exc_info": (type(error), error, error.__traceback__)} if error else {}
    )
    log = logging.getLogger(__package__)

    frames = context.get("source_traceback")
    if not frames:
        # No creation-site traceback available: log the bare message.
        log.error(
            "Error doing job: %s",
            context["message"],
            **log_kwargs,  # type: ignore[arg-type]
        )
        return

    formatted_stack = "".join(traceback.format_list(frames))
    log.error(
        "Error doing job: %s: %s",
        context["message"],
        formatted_stack,
        **log_kwargs,  # type: ignore[arg-type]
    )
def start_dashboard(args) -> None: def start_dashboard(args) -> None:
"""Start the dashboard.""" """Start the dashboard."""
settings.parse_args(args) settings.parse_args(args)
@ -26,6 +123,8 @@ def start_dashboard(args) -> None:
storage.save(path) storage.save(path)
settings.cookie_secret = storage.cookie_secret settings.cookie_secret = storage.cookie_secret
asyncio.set_event_loop_policy(DashboardEventLoopPolicy(settings.verbose))
try: try:
asyncio.run(async_start(args)) asyncio.run(async_start(args))
except KeyboardInterrupt: except KeyboardInterrupt:

43
esphome/dashboard/dns.py Normal file
View file

@ -0,0 +1,43 @@
from __future__ import annotations
import asyncio
import sys
from icmplib import NameLookupError, async_resolve
if sys.version_info >= (3, 11):
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout
async def _async_resolve_wrapper(hostname: str) -> list[str] | Exception:
    """Resolve ``hostname`` with icmplib, giving up after two seconds.

    Timeouts, name lookup failures and ``UnicodeError`` are not raised;
    the exception object is returned in place of the address list.
    """
    try:
        async with async_timeout(2):
            resolved = await async_resolve(hostname)
    except (asyncio.TimeoutError, NameLookupError, UnicodeError) as err:
        return err
    return resolved
class DNSCache:
    """DNS cache for the dashboard.

    Results are stored per hostname together with a monotonic expiry time.
    Whatever the resolver wrapper returns is cached as-is, so a failed
    lookup (an ``Exception`` instance) is remembered for the same duration
    as a successful one.
    """

    def __init__(self, ttl: int | None = 120) -> None:
        """Initialize the DNSCache.

        ``ttl`` is the lifetime of a cache entry in seconds. ``None`` means
        entries never expire.
        """
        self._cache: dict[str, tuple[float, list[str] | Exception]] = {}
        self._ttl = ttl

    def _expires_at(self, now_monotonic: float) -> float:
        """Return the monotonic time at which an entry stored now expires."""
        if self._ttl is None:
            # Without this guard ``now_monotonic + None`` raises TypeError
            # even though the constructor accepts ``ttl=None``.
            return float("inf")
        return now_monotonic + self._ttl

    async def async_resolve(
        self, hostname: str, now_monotonic: float
    ) -> list[str] | Exception:
        """Resolve a hostname to a list of IP addresses.

        ``now_monotonic`` is the caller's current monotonic time; it is used
        both to check whether a cached entry is still valid and to compute
        the expiry of a fresh entry. Returns the cached or freshly resolved
        value, which is an ``Exception`` instance when the lookup failed.
        """
        cached = self._cache.get(hostname)
        if cached is not None:
            expire_time, addresses = cached
            if expire_time > now_monotonic:
                return addresses

        # Expiry is based on the time the lookup was requested, not on the
        # time it completed.
        expires = self._expires_at(now_monotonic)
        addresses = await _async_resolve_wrapper(hostname)
        self._cache[hostname] = (expires, addresses)
        return addresses

View file

@ -14,7 +14,19 @@ from .util.password import password_hash
class DashboardSettings: class DashboardSettings:
"""Settings for the dashboard.""" """Settings for the dashboard."""
__slots__ = (
"config_dir",
"password_hash",
"username",
"using_password",
"on_ha_addon",
"cookie_secret",
"absolute_config_dir",
"verbose",
)
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the dashboard settings."""
self.config_dir: str = "" self.config_dir: str = ""
self.password_hash: str = "" self.password_hash: str = ""
self.username: str = "" self.username: str = ""
@ -22,8 +34,10 @@ class DashboardSettings:
self.on_ha_addon: bool = False self.on_ha_addon: bool = False
self.cookie_secret: str | None = None self.cookie_secret: str | None = None
self.absolute_config_dir: Path | None = None self.absolute_config_dir: Path | None = None
self.verbose: bool = False
def parse_args(self, args: Any) -> None: def parse_args(self, args: Any) -> None:
"""Parse the arguments."""
self.on_ha_addon: bool = args.ha_addon self.on_ha_addon: bool = args.ha_addon
password = args.password or os.getenv("PASSWORD") or "" password = args.password or os.getenv("PASSWORD") or ""
if not self.on_ha_addon: if not self.on_ha_addon:
@ -33,6 +47,7 @@ class DashboardSettings:
self.password_hash = password_hash(password) self.password_hash = password_hash(password)
self.config_dir = args.configuration self.config_dir = args.configuration
self.absolute_config_dir = Path(self.config_dir).resolve() self.absolute_config_dir = Path(self.config_dir).resolve()
self.verbose = args.verbose
CORE.config_path = os.path.join(self.config_dir, ".") CORE.config_path = os.path.join(self.config_dir, ".")
@property @property

View file

@ -1,20 +1,20 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import os import logging
import time
from typing import cast from typing import cast
from icmplib import Host, SocketPermissionError, async_ping
from ..const import MAX_EXECUTOR_WORKERS
from ..core import DASHBOARD from ..core import DASHBOARD
from ..entries import DashboardEntry, bool_to_entry_state from ..entries import DashboardEntry, EntryState, bool_to_entry_state
from ..util.itertools import chunked from ..util.itertools import chunked
from ..util.subprocess import async_system_command_status
_LOGGER = logging.getLogger(__name__)
async def _async_ping_host(host: str) -> bool: GROUP_SIZE = int(MAX_EXECUTOR_WORKERS / 2)
"""Ping a host."""
return await async_system_command_status(
["ping", "-n" if os.name == "nt" else "-c", "1", host]
)
class PingStatus: class PingStatus:
@ -27,6 +27,10 @@ class PingStatus:
"""Run the ping status.""" """Run the ping status."""
dashboard = DASHBOARD dashboard = DASHBOARD
entries = dashboard.entries entries = dashboard.entries
privileged = await _can_use_icmp_lib_with_privilege()
if privileged is None:
_LOGGER.warning("Cannot use icmplib because privileges are insufficient")
return
while not dashboard.stop_event.is_set(): while not dashboard.stop_event.is_set():
# Only ping if the dashboard is open # Only ping if the dashboard is open
@ -36,15 +40,68 @@ class PingStatus:
to_ping: list[DashboardEntry] = [ to_ping: list[DashboardEntry] = [
entry for entry in current_entries if entry.address is not None entry for entry in current_entries if entry.address is not None
] ]
for ping_group in chunked(to_ping, 16):
# Resolve DNS for all entries
entries_with_addresses: dict[DashboardEntry, list[str]] = {}
for ping_group in chunked(to_ping, GROUP_SIZE):
ping_group = cast(list[DashboardEntry], ping_group) ping_group = cast(list[DashboardEntry], ping_group)
results = await asyncio.gather( now_monotonic = time.monotonic()
*(_async_ping_host(entry.address) for entry in ping_group), dns_results = await asyncio.gather(
*(
dashboard.dns_cache.async_resolve(entry.address, now_monotonic)
for entry in ping_group
),
return_exceptions=True, return_exceptions=True,
) )
for entry, result in zip(ping_group, results):
for entry, result in zip(ping_group, dns_results):
if isinstance(result, Exception): if isinstance(result, Exception):
result = False entries.async_set_state(entry, EntryState.UNKNOWN)
continue
if isinstance(result, BaseException):
raise result
entries_with_addresses[entry] = result
# Ping all entries with valid addresses
for ping_group in chunked(entries_with_addresses.items(), GROUP_SIZE):
entry_addresses = cast(tuple[DashboardEntry, list[str]], ping_group)
results = await asyncio.gather(
*(
async_ping(addresses[0], privileged=privileged)
for _, addresses in entry_addresses
),
return_exceptions=True,
)
for entry_addresses, result in zip(entry_addresses, results):
if isinstance(result, Exception):
ping_result = False
elif isinstance(result, BaseException): elif isinstance(result, BaseException):
raise result raise result
entries.async_set_state(entry, bool_to_entry_state(result)) else:
host: Host = result
ping_result = host.is_alive
entry, _ = entry_addresses
entries.async_set_state(entry, bool_to_entry_state(ping_result))
async def _can_use_icmp_lib_with_privilege() -> None | bool:
    """Probe which icmplib socket mode can be used.

    Sends a zero-count ping to 127.0.0.1 first in privileged mode and, if
    that is refused with ``SocketPermissionError``, in unprivileged mode.

    Returns ``True`` when privileged mode works, ``False`` when only
    unprivileged mode works and ``None`` when neither socket can be created.
    """
    try:
        await async_ping("127.0.0.1", count=0, timeout=0, privileged=True)
    except SocketPermissionError:
        # Privileged mode is refused; try unprivileged mode below.
        pass
    else:
        _LOGGER.debug("Using icmplib in privileged=True mode")
        return True

    try:
        await async_ping("127.0.0.1", count=0, timeout=0, privileged=False)
    except SocketPermissionError:
        _LOGGER.debug(
            "Cannot use icmplib because privileges are insufficient to create the"
            " socket"
        )
        return None

    _LOGGER.debug("Using icmplib in privileged=False mode")
    return False

View file

@ -9,6 +9,7 @@ import hashlib
import json import json
import logging import logging
import os import os
import time
import secrets import secrets
import shutil import shutil
import subprocess import subprocess
@ -302,16 +303,28 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
port = json_message["port"] port = json_message["port"]
if ( if (
port == "OTA" # pylint: disable=too-many-boolean-expressions port == "OTA" # pylint: disable=too-many-boolean-expressions
and (mdns := dashboard.mdns_status)
and (entry := entries.get(config_file)) and (entry := entries.get(config_file))
and entry.loaded_integrations and entry.loaded_integrations
and "api" in entry.loaded_integrations and "api" in entry.loaded_integrations
and (address := await mdns.async_resolve_host(entry.name)) ):
if (mdns := dashboard.mdns_status) and (
address := await mdns.async_resolve_host(entry.name)
): ):
# Use the IP address if available but only # Use the IP address if available but only
# if the API is loaded and the device is online # if the API is loaded and the device is online
# since MQTT logging will not work otherwise # since MQTT logging will not work otherwise
port = address port = address
elif (
entry.address
and (
address_list := await dashboard.dns_cache.async_resolve(
entry.address, time.monotonic()
)
)
and not isinstance(address_list, Exception)
):
# If mdns is not available, try to use the DNS cache
port = address_list[0]
return [ return [
*DASHBOARD_COMMAND, *DASHBOARD_COMMAND,

View file

@ -1,7 +1,9 @@
async_timeout==4.0.3; python_version <= "3.10"
voluptuous==0.14.1 voluptuous==0.14.1
PyYAML==6.0.1 PyYAML==6.0.1
paho-mqtt==1.6.1 paho-mqtt==1.6.1
colorama==0.4.6 colorama==0.4.6
icmplib==3.0.4
tornado==6.4 tornado==6.4
tzlocal==5.2 # from time tzlocal==5.2 # from time
tzdata>=2021.1 # from time tzdata>=2021.1 # from time