mirror of
https://github.com/esphome/esphome.git
synced 2024-12-22 05:24:53 +01:00
dashboard: refactor ping implementation to be more efficient (#6002)
This commit is contained in:
parent
87301a2e76
commit
6dfdcff66c
8 changed files with 255 additions and 21 deletions
|
@ -4,5 +4,7 @@ EVENT_ENTRY_ADDED = "entry_added"
|
|||
EVENT_ENTRY_REMOVED = "entry_removed"
|
||||
EVENT_ENTRY_UPDATED = "entry_updated"
|
||||
EVENT_ENTRY_STATE_CHANGED = "entry_state_changed"
|
||||
MAX_EXECUTOR_WORKERS = 48
|
||||
|
||||
|
||||
SENTINEL = object()
|
||||
|
|
|
@ -8,6 +8,7 @@ from functools import partial
|
|||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from ..zeroconf import DiscoveredImport
|
||||
from .dns import DNSCache
|
||||
from .entries import DashboardEntries
|
||||
from .settings import DashboardSettings
|
||||
|
||||
|
@ -69,6 +70,7 @@ class ESPHomeDashboard:
|
|||
"mqtt_ping_request",
|
||||
"mdns_status",
|
||||
"settings",
|
||||
"dns_cache",
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
@ -81,7 +83,8 @@ class ESPHomeDashboard:
|
|||
self.ping_request: asyncio.Event | None = None
|
||||
self.mqtt_ping_request = threading.Event()
|
||||
self.mdns_status: MDNSStatus | None = None
|
||||
self.settings: DashboardSettings = DashboardSettings()
|
||||
self.settings = DashboardSettings()
|
||||
self.dns_cache = DNSCache()
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
"""Setup the dashboard."""
|
||||
|
|
|
@ -1,11 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import traceback
|
||||
from asyncio import events
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from time import monotonic
|
||||
from typing import Any
|
||||
|
||||
from esphome.storage_json import EsphomeStorageJSON, esphome_storage_path
|
||||
|
||||
from .const import MAX_EXECUTOR_WORKERS
|
||||
from .core import DASHBOARD
|
||||
from .web_server import make_app, start_web_server
|
||||
|
||||
|
@ -14,6 +22,95 @@ ENV_DEV = "ESPHOME_DASHBOARD_DEV"
|
|||
settings = DASHBOARD.settings
|
||||
|
||||
|
||||
def can_use_pidfd() -> bool:
|
||||
"""Check if pidfd_open is available.
|
||||
|
||||
Back ported from cpython 3.12
|
||||
"""
|
||||
if not hasattr(os, "pidfd_open"):
|
||||
return False
|
||||
try:
|
||||
pid = os.getpid()
|
||||
os.close(os.pidfd_open(pid, 0))
|
||||
except OSError:
|
||||
# blocked by security policy like SECCOMP
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class DashboardEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
|
||||
"""Event loop policy for Home Assistant."""
|
||||
|
||||
def __init__(self, debug: bool) -> None:
|
||||
"""Init the event loop policy."""
|
||||
super().__init__()
|
||||
self.debug = debug
|
||||
self._watcher: asyncio.AbstractChildWatcher | None = None
|
||||
|
||||
def _init_watcher(self) -> None:
|
||||
"""Initialize the watcher for child processes.
|
||||
|
||||
Back ported from cpython 3.12
|
||||
"""
|
||||
with events._lock: # type: ignore[attr-defined] # pylint: disable=protected-access
|
||||
if self._watcher is None: # pragma: no branch
|
||||
if can_use_pidfd():
|
||||
self._watcher = asyncio.PidfdChildWatcher()
|
||||
else:
|
||||
self._watcher = asyncio.ThreadedChildWatcher()
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
self._watcher.attach_loop(
|
||||
self._local._loop # type: ignore[attr-defined] # pylint: disable=protected-access
|
||||
)
|
||||
|
||||
@property
|
||||
def loop_name(self) -> str:
|
||||
"""Return name of the loop."""
|
||||
return self._loop_factory.__name__ # type: ignore[no-any-return,attr-defined]
|
||||
|
||||
def new_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||
"""Get the event loop."""
|
||||
loop: asyncio.AbstractEventLoop = super().new_event_loop()
|
||||
loop.set_exception_handler(_async_loop_exception_handler)
|
||||
|
||||
if self.debug:
|
||||
loop.set_debug(True)
|
||||
|
||||
executor = ThreadPoolExecutor(
|
||||
thread_name_prefix="SyncWorker", max_workers=MAX_EXECUTOR_WORKERS
|
||||
)
|
||||
loop.set_default_executor(executor)
|
||||
# bind the built-in time.monotonic directly as loop.time to avoid the
|
||||
# overhead of the additional method call since its the most called loop
|
||||
# method and its roughly 10%+ of all the call time in base_events.py
|
||||
loop.time = monotonic # type: ignore[method-assign]
|
||||
return loop
|
||||
|
||||
|
||||
def _async_loop_exception_handler(_: Any, context: dict[str, Any]) -> None:
|
||||
"""Handle all exception inside the core loop."""
|
||||
kwargs = {}
|
||||
if exception := context.get("exception"):
|
||||
kwargs["exc_info"] = (type(exception), exception, exception.__traceback__)
|
||||
|
||||
logger = logging.getLogger(__package__)
|
||||
if source_traceback := context.get("source_traceback"):
|
||||
stack_summary = "".join(traceback.format_list(source_traceback))
|
||||
logger.error(
|
||||
"Error doing job: %s: %s",
|
||||
context["message"],
|
||||
stack_summary,
|
||||
**kwargs, # type: ignore[arg-type]
|
||||
)
|
||||
return
|
||||
|
||||
logger.error(
|
||||
"Error doing job: %s",
|
||||
context["message"],
|
||||
**kwargs, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def start_dashboard(args) -> None:
|
||||
"""Start the dashboard."""
|
||||
settings.parse_args(args)
|
||||
|
@ -26,6 +123,8 @@ def start_dashboard(args) -> None:
|
|||
storage.save(path)
|
||||
settings.cookie_secret = storage.cookie_secret
|
||||
|
||||
asyncio.set_event_loop_policy(DashboardEventLoopPolicy(settings.verbose))
|
||||
|
||||
try:
|
||||
asyncio.run(async_start(args))
|
||||
except KeyboardInterrupt:
|
||||
|
|
43
esphome/dashboard/dns.py
Normal file
43
esphome/dashboard/dns.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from icmplib import NameLookupError, async_resolve
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from asyncio import timeout as async_timeout
|
||||
else:
|
||||
from async_timeout import timeout as async_timeout
|
||||
|
||||
|
||||
async def _async_resolve_wrapper(hostname: str) -> list[str] | Exception:
|
||||
"""Wrap the icmplib async_resolve function."""
|
||||
try:
|
||||
async with async_timeout(2):
|
||||
return await async_resolve(hostname)
|
||||
except (asyncio.TimeoutError, NameLookupError, UnicodeError) as ex:
|
||||
return ex
|
||||
|
||||
|
||||
class DNSCache:
|
||||
"""DNS cache for the dashboard."""
|
||||
|
||||
def __init__(self, ttl: int | None = 120) -> None:
|
||||
"""Initialize the DNSCache."""
|
||||
self._cache: dict[str, tuple[float, list[str] | Exception]] = {}
|
||||
self._ttl = ttl
|
||||
|
||||
async def async_resolve(
|
||||
self, hostname: str, now_monotonic: float
|
||||
) -> list[str] | Exception:
|
||||
"""Resolve a hostname to a list of IP address."""
|
||||
if expire_time_addresses := self._cache.get(hostname):
|
||||
expire_time, addresses = expire_time_addresses
|
||||
if expire_time > now_monotonic:
|
||||
return addresses
|
||||
|
||||
expires = now_monotonic + self._ttl
|
||||
addresses = await _async_resolve_wrapper(hostname)
|
||||
self._cache[hostname] = (expires, addresses)
|
||||
return addresses
|
|
@ -14,7 +14,19 @@ from .util.password import password_hash
|
|||
class DashboardSettings:
|
||||
"""Settings for the dashboard."""
|
||||
|
||||
__slots__ = (
|
||||
"config_dir",
|
||||
"password_hash",
|
||||
"username",
|
||||
"using_password",
|
||||
"on_ha_addon",
|
||||
"cookie_secret",
|
||||
"absolute_config_dir",
|
||||
"verbose",
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the dashboard settings."""
|
||||
self.config_dir: str = ""
|
||||
self.password_hash: str = ""
|
||||
self.username: str = ""
|
||||
|
@ -22,8 +34,10 @@ class DashboardSettings:
|
|||
self.on_ha_addon: bool = False
|
||||
self.cookie_secret: str | None = None
|
||||
self.absolute_config_dir: Path | None = None
|
||||
self.verbose: bool = False
|
||||
|
||||
def parse_args(self, args: Any) -> None:
|
||||
"""Parse the arguments."""
|
||||
self.on_ha_addon: bool = args.ha_addon
|
||||
password = args.password or os.getenv("PASSWORD") or ""
|
||||
if not self.on_ha_addon:
|
||||
|
@ -33,6 +47,7 @@ class DashboardSettings:
|
|||
self.password_hash = password_hash(password)
|
||||
self.config_dir = args.configuration
|
||||
self.absolute_config_dir = Path(self.config_dir).resolve()
|
||||
self.verbose = args.verbose
|
||||
CORE.config_path = os.path.join(self.config_dir, ".")
|
||||
|
||||
@property
|
||||
|
|
|
@ -1,20 +1,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from icmplib import Host, SocketPermissionError, async_ping
|
||||
|
||||
from ..const import MAX_EXECUTOR_WORKERS
|
||||
from ..core import DASHBOARD
|
||||
from ..entries import DashboardEntry, bool_to_entry_state
|
||||
from ..entries import DashboardEntry, EntryState, bool_to_entry_state
|
||||
from ..util.itertools import chunked
|
||||
from ..util.subprocess import async_system_command_status
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
async def _async_ping_host(host: str) -> bool:
|
||||
"""Ping a host."""
|
||||
return await async_system_command_status(
|
||||
["ping", "-n" if os.name == "nt" else "-c", "1", host]
|
||||
)
|
||||
GROUP_SIZE = int(MAX_EXECUTOR_WORKERS / 2)
|
||||
|
||||
|
||||
class PingStatus:
|
||||
|
@ -27,6 +27,10 @@ class PingStatus:
|
|||
"""Run the ping status."""
|
||||
dashboard = DASHBOARD
|
||||
entries = dashboard.entries
|
||||
privileged = await _can_use_icmp_lib_with_privilege()
|
||||
if privileged is None:
|
||||
_LOGGER.warning("Cannot use icmplib because privileges are insufficient")
|
||||
return
|
||||
|
||||
while not dashboard.stop_event.is_set():
|
||||
# Only ping if the dashboard is open
|
||||
|
@ -36,15 +40,68 @@ class PingStatus:
|
|||
to_ping: list[DashboardEntry] = [
|
||||
entry for entry in current_entries if entry.address is not None
|
||||
]
|
||||
for ping_group in chunked(to_ping, 16):
|
||||
|
||||
# Resolve DNS for all entries
|
||||
entries_with_addresses: dict[DashboardEntry, list[str]] = {}
|
||||
for ping_group in chunked(to_ping, GROUP_SIZE):
|
||||
ping_group = cast(list[DashboardEntry], ping_group)
|
||||
results = await asyncio.gather(
|
||||
*(_async_ping_host(entry.address) for entry in ping_group),
|
||||
now_monotonic = time.monotonic()
|
||||
dns_results = await asyncio.gather(
|
||||
*(
|
||||
dashboard.dns_cache.async_resolve(entry.address, now_monotonic)
|
||||
for entry in ping_group
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for entry, result in zip(ping_group, results):
|
||||
|
||||
for entry, result in zip(ping_group, dns_results):
|
||||
if isinstance(result, Exception):
|
||||
result = False
|
||||
entries.async_set_state(entry, EntryState.UNKNOWN)
|
||||
continue
|
||||
if isinstance(result, BaseException):
|
||||
raise result
|
||||
entries_with_addresses[entry] = result
|
||||
|
||||
# Ping all entries with valid addresses
|
||||
for ping_group in chunked(entries_with_addresses.items(), GROUP_SIZE):
|
||||
entry_addresses = cast(tuple[DashboardEntry, list[str]], ping_group)
|
||||
|
||||
results = await asyncio.gather(
|
||||
*(
|
||||
async_ping(addresses[0], privileged=privileged)
|
||||
for _, addresses in entry_addresses
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
for entry_addresses, result in zip(entry_addresses, results):
|
||||
if isinstance(result, Exception):
|
||||
ping_result = False
|
||||
elif isinstance(result, BaseException):
|
||||
raise result
|
||||
entries.async_set_state(entry, bool_to_entry_state(result))
|
||||
else:
|
||||
host: Host = result
|
||||
ping_result = host.is_alive
|
||||
entry, _ = entry_addresses
|
||||
entries.async_set_state(entry, bool_to_entry_state(ping_result))
|
||||
|
||||
|
||||
async def _can_use_icmp_lib_with_privilege() -> None | bool:
|
||||
"""Verify we can create a raw socket."""
|
||||
try:
|
||||
await async_ping("127.0.0.1", count=0, timeout=0, privileged=True)
|
||||
except SocketPermissionError:
|
||||
try:
|
||||
await async_ping("127.0.0.1", count=0, timeout=0, privileged=False)
|
||||
except SocketPermissionError:
|
||||
_LOGGER.debug(
|
||||
"Cannot use icmplib because privileges are insufficient to create the"
|
||||
" socket"
|
||||
)
|
||||
return None
|
||||
|
||||
_LOGGER.debug("Using icmplib in privileged=False mode")
|
||||
return False
|
||||
|
||||
_LOGGER.debug("Using icmplib in privileged=True mode")
|
||||
return True
|
||||
|
|
|
@ -9,6 +9,7 @@ import hashlib
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import secrets
|
||||
import shutil
|
||||
import subprocess
|
||||
|
@ -302,16 +303,28 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
|
|||
port = json_message["port"]
|
||||
if (
|
||||
port == "OTA" # pylint: disable=too-many-boolean-expressions
|
||||
and (mdns := dashboard.mdns_status)
|
||||
and (entry := entries.get(config_file))
|
||||
and entry.loaded_integrations
|
||||
and "api" in entry.loaded_integrations
|
||||
and (address := await mdns.async_resolve_host(entry.name))
|
||||
):
|
||||
# Use the IP address if available but only
|
||||
# if the API is loaded and the device is online
|
||||
# since MQTT logging will not work otherwise
|
||||
port = address
|
||||
if (mdns := dashboard.mdns_status) and (
|
||||
address := await mdns.async_resolve_host(entry.name)
|
||||
):
|
||||
# Use the IP address if available but only
|
||||
# if the API is loaded and the device is online
|
||||
# since MQTT logging will not work otherwise
|
||||
port = address
|
||||
elif (
|
||||
entry.address
|
||||
and (
|
||||
address_list := await dashboard.dns_cache.async_resolve(
|
||||
entry.address, time.monotonic()
|
||||
)
|
||||
)
|
||||
and not isinstance(address_list, Exception)
|
||||
):
|
||||
# If mdns is not available, try to use the DNS cache
|
||||
port = address_list[0]
|
||||
|
||||
return [
|
||||
*DASHBOARD_COMMAND,
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
async_timeout==4.0.3; python_version <= "3.10"
|
||||
voluptuous==0.14.1
|
||||
PyYAML==6.0.1
|
||||
paho-mqtt==1.6.1
|
||||
colorama==0.4.6
|
||||
icmplib==3.0.4
|
||||
tornado==6.4
|
||||
tzlocal==5.2 # from time
|
||||
tzdata>=2021.1 # from time
|
||||
|
|
Loading…
Reference in a new issue