mirror of
https://github.com/esphome/esphome.git
synced 2024-11-14 02:58:11 +01:00
dashboard: refactor ping implementation to be more efficient (#6002)
This commit is contained in:
parent
87301a2e76
commit
6dfdcff66c
8 changed files with 255 additions and 21 deletions
|
@ -4,5 +4,7 @@ EVENT_ENTRY_ADDED = "entry_added"
|
||||||
EVENT_ENTRY_REMOVED = "entry_removed"
|
EVENT_ENTRY_REMOVED = "entry_removed"
|
||||||
EVENT_ENTRY_UPDATED = "entry_updated"
|
EVENT_ENTRY_UPDATED = "entry_updated"
|
||||||
EVENT_ENTRY_STATE_CHANGED = "entry_state_changed"
|
EVENT_ENTRY_STATE_CHANGED = "entry_state_changed"
|
||||||
|
MAX_EXECUTOR_WORKERS = 48
|
||||||
|
|
||||||
|
|
||||||
SENTINEL = object()
|
SENTINEL = object()
|
||||||
|
|
|
@ -8,6 +8,7 @@ from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Callable
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
from ..zeroconf import DiscoveredImport
|
from ..zeroconf import DiscoveredImport
|
||||||
|
from .dns import DNSCache
|
||||||
from .entries import DashboardEntries
|
from .entries import DashboardEntries
|
||||||
from .settings import DashboardSettings
|
from .settings import DashboardSettings
|
||||||
|
|
||||||
|
@ -69,6 +70,7 @@ class ESPHomeDashboard:
|
||||||
"mqtt_ping_request",
|
"mqtt_ping_request",
|
||||||
"mdns_status",
|
"mdns_status",
|
||||||
"settings",
|
"settings",
|
||||||
|
"dns_cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@ -81,7 +83,8 @@ class ESPHomeDashboard:
|
||||||
self.ping_request: asyncio.Event | None = None
|
self.ping_request: asyncio.Event | None = None
|
||||||
self.mqtt_ping_request = threading.Event()
|
self.mqtt_ping_request = threading.Event()
|
||||||
self.mdns_status: MDNSStatus | None = None
|
self.mdns_status: MDNSStatus | None = None
|
||||||
self.settings: DashboardSettings = DashboardSettings()
|
self.settings = DashboardSettings()
|
||||||
|
self.dns_cache = DNSCache()
|
||||||
|
|
||||||
async def async_setup(self) -> None:
|
async def async_setup(self) -> None:
|
||||||
"""Setup the dashboard."""
|
"""Setup the dashboard."""
|
||||||
|
|
|
@ -1,11 +1,19 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
import threading
|
||||||
|
import traceback
|
||||||
|
from asyncio import events
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from time import monotonic
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from esphome.storage_json import EsphomeStorageJSON, esphome_storage_path
|
from esphome.storage_json import EsphomeStorageJSON, esphome_storage_path
|
||||||
|
|
||||||
|
from .const import MAX_EXECUTOR_WORKERS
|
||||||
from .core import DASHBOARD
|
from .core import DASHBOARD
|
||||||
from .web_server import make_app, start_web_server
|
from .web_server import make_app, start_web_server
|
||||||
|
|
||||||
|
@ -14,6 +22,95 @@ ENV_DEV = "ESPHOME_DASHBOARD_DEV"
|
||||||
settings = DASHBOARD.settings
|
settings = DASHBOARD.settings
|
||||||
|
|
||||||
|
|
||||||
|
def can_use_pidfd() -> bool:
|
||||||
|
"""Check if pidfd_open is available.
|
||||||
|
|
||||||
|
Back ported from cpython 3.12
|
||||||
|
"""
|
||||||
|
if not hasattr(os, "pidfd_open"):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
pid = os.getpid()
|
||||||
|
os.close(os.pidfd_open(pid, 0))
|
||||||
|
except OSError:
|
||||||
|
# blocked by security policy like SECCOMP
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class DashboardEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
|
||||||
|
"""Event loop policy for Home Assistant."""
|
||||||
|
|
||||||
|
def __init__(self, debug: bool) -> None:
|
||||||
|
"""Init the event loop policy."""
|
||||||
|
super().__init__()
|
||||||
|
self.debug = debug
|
||||||
|
self._watcher: asyncio.AbstractChildWatcher | None = None
|
||||||
|
|
||||||
|
def _init_watcher(self) -> None:
|
||||||
|
"""Initialize the watcher for child processes.
|
||||||
|
|
||||||
|
Back ported from cpython 3.12
|
||||||
|
"""
|
||||||
|
with events._lock: # type: ignore[attr-defined] # pylint: disable=protected-access
|
||||||
|
if self._watcher is None: # pragma: no branch
|
||||||
|
if can_use_pidfd():
|
||||||
|
self._watcher = asyncio.PidfdChildWatcher()
|
||||||
|
else:
|
||||||
|
self._watcher = asyncio.ThreadedChildWatcher()
|
||||||
|
if threading.current_thread() is threading.main_thread():
|
||||||
|
self._watcher.attach_loop(
|
||||||
|
self._local._loop # type: ignore[attr-defined] # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop_name(self) -> str:
|
||||||
|
"""Return name of the loop."""
|
||||||
|
return self._loop_factory.__name__ # type: ignore[no-any-return,attr-defined]
|
||||||
|
|
||||||
|
def new_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||||
|
"""Get the event loop."""
|
||||||
|
loop: asyncio.AbstractEventLoop = super().new_event_loop()
|
||||||
|
loop.set_exception_handler(_async_loop_exception_handler)
|
||||||
|
|
||||||
|
if self.debug:
|
||||||
|
loop.set_debug(True)
|
||||||
|
|
||||||
|
executor = ThreadPoolExecutor(
|
||||||
|
thread_name_prefix="SyncWorker", max_workers=MAX_EXECUTOR_WORKERS
|
||||||
|
)
|
||||||
|
loop.set_default_executor(executor)
|
||||||
|
# bind the built-in time.monotonic directly as loop.time to avoid the
|
||||||
|
# overhead of the additional method call since its the most called loop
|
||||||
|
# method and its roughly 10%+ of all the call time in base_events.py
|
||||||
|
loop.time = monotonic # type: ignore[method-assign]
|
||||||
|
return loop
|
||||||
|
|
||||||
|
|
||||||
|
def _async_loop_exception_handler(_: Any, context: dict[str, Any]) -> None:
|
||||||
|
"""Handle all exception inside the core loop."""
|
||||||
|
kwargs = {}
|
||||||
|
if exception := context.get("exception"):
|
||||||
|
kwargs["exc_info"] = (type(exception), exception, exception.__traceback__)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__package__)
|
||||||
|
if source_traceback := context.get("source_traceback"):
|
||||||
|
stack_summary = "".join(traceback.format_list(source_traceback))
|
||||||
|
logger.error(
|
||||||
|
"Error doing job: %s: %s",
|
||||||
|
context["message"],
|
||||||
|
stack_summary,
|
||||||
|
**kwargs, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
"Error doing job: %s",
|
||||||
|
context["message"],
|
||||||
|
**kwargs, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def start_dashboard(args) -> None:
|
def start_dashboard(args) -> None:
|
||||||
"""Start the dashboard."""
|
"""Start the dashboard."""
|
||||||
settings.parse_args(args)
|
settings.parse_args(args)
|
||||||
|
@ -26,6 +123,8 @@ def start_dashboard(args) -> None:
|
||||||
storage.save(path)
|
storage.save(path)
|
||||||
settings.cookie_secret = storage.cookie_secret
|
settings.cookie_secret = storage.cookie_secret
|
||||||
|
|
||||||
|
asyncio.set_event_loop_policy(DashboardEventLoopPolicy(settings.verbose))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
asyncio.run(async_start(args))
|
asyncio.run(async_start(args))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
|
43
esphome/dashboard/dns.py
Normal file
43
esphome/dashboard/dns.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from icmplib import NameLookupError, async_resolve
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from asyncio import timeout as async_timeout
|
||||||
|
else:
|
||||||
|
from async_timeout import timeout as async_timeout
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_resolve_wrapper(hostname: str) -> list[str] | Exception:
|
||||||
|
"""Wrap the icmplib async_resolve function."""
|
||||||
|
try:
|
||||||
|
async with async_timeout(2):
|
||||||
|
return await async_resolve(hostname)
|
||||||
|
except (asyncio.TimeoutError, NameLookupError, UnicodeError) as ex:
|
||||||
|
return ex
|
||||||
|
|
||||||
|
|
||||||
|
class DNSCache:
|
||||||
|
"""DNS cache for the dashboard."""
|
||||||
|
|
||||||
|
def __init__(self, ttl: int | None = 120) -> None:
|
||||||
|
"""Initialize the DNSCache."""
|
||||||
|
self._cache: dict[str, tuple[float, list[str] | Exception]] = {}
|
||||||
|
self._ttl = ttl
|
||||||
|
|
||||||
|
async def async_resolve(
|
||||||
|
self, hostname: str, now_monotonic: float
|
||||||
|
) -> list[str] | Exception:
|
||||||
|
"""Resolve a hostname to a list of IP address."""
|
||||||
|
if expire_time_addresses := self._cache.get(hostname):
|
||||||
|
expire_time, addresses = expire_time_addresses
|
||||||
|
if expire_time > now_monotonic:
|
||||||
|
return addresses
|
||||||
|
|
||||||
|
expires = now_monotonic + self._ttl
|
||||||
|
addresses = await _async_resolve_wrapper(hostname)
|
||||||
|
self._cache[hostname] = (expires, addresses)
|
||||||
|
return addresses
|
|
@ -14,7 +14,19 @@ from .util.password import password_hash
|
||||||
class DashboardSettings:
|
class DashboardSettings:
|
||||||
"""Settings for the dashboard."""
|
"""Settings for the dashboard."""
|
||||||
|
|
||||||
|
__slots__ = (
|
||||||
|
"config_dir",
|
||||||
|
"password_hash",
|
||||||
|
"username",
|
||||||
|
"using_password",
|
||||||
|
"on_ha_addon",
|
||||||
|
"cookie_secret",
|
||||||
|
"absolute_config_dir",
|
||||||
|
"verbose",
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the dashboard settings."""
|
||||||
self.config_dir: str = ""
|
self.config_dir: str = ""
|
||||||
self.password_hash: str = ""
|
self.password_hash: str = ""
|
||||||
self.username: str = ""
|
self.username: str = ""
|
||||||
|
@ -22,8 +34,10 @@ class DashboardSettings:
|
||||||
self.on_ha_addon: bool = False
|
self.on_ha_addon: bool = False
|
||||||
self.cookie_secret: str | None = None
|
self.cookie_secret: str | None = None
|
||||||
self.absolute_config_dir: Path | None = None
|
self.absolute_config_dir: Path | None = None
|
||||||
|
self.verbose: bool = False
|
||||||
|
|
||||||
def parse_args(self, args: Any) -> None:
|
def parse_args(self, args: Any) -> None:
|
||||||
|
"""Parse the arguments."""
|
||||||
self.on_ha_addon: bool = args.ha_addon
|
self.on_ha_addon: bool = args.ha_addon
|
||||||
password = args.password or os.getenv("PASSWORD") or ""
|
password = args.password or os.getenv("PASSWORD") or ""
|
||||||
if not self.on_ha_addon:
|
if not self.on_ha_addon:
|
||||||
|
@ -33,6 +47,7 @@ class DashboardSettings:
|
||||||
self.password_hash = password_hash(password)
|
self.password_hash = password_hash(password)
|
||||||
self.config_dir = args.configuration
|
self.config_dir = args.configuration
|
||||||
self.absolute_config_dir = Path(self.config_dir).resolve()
|
self.absolute_config_dir = Path(self.config_dir).resolve()
|
||||||
|
self.verbose = args.verbose
|
||||||
CORE.config_path = os.path.join(self.config_dir, ".")
|
CORE.config_path = os.path.join(self.config_dir, ".")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -1,20 +1,20 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import logging
|
||||||
|
import time
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
|
from icmplib import Host, SocketPermissionError, async_ping
|
||||||
|
|
||||||
|
from ..const import MAX_EXECUTOR_WORKERS
|
||||||
from ..core import DASHBOARD
|
from ..core import DASHBOARD
|
||||||
from ..entries import DashboardEntry, bool_to_entry_state
|
from ..entries import DashboardEntry, EntryState, bool_to_entry_state
|
||||||
from ..util.itertools import chunked
|
from ..util.itertools import chunked
|
||||||
from ..util.subprocess import async_system_command_status
|
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def _async_ping_host(host: str) -> bool:
|
GROUP_SIZE = int(MAX_EXECUTOR_WORKERS / 2)
|
||||||
"""Ping a host."""
|
|
||||||
return await async_system_command_status(
|
|
||||||
["ping", "-n" if os.name == "nt" else "-c", "1", host]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PingStatus:
|
class PingStatus:
|
||||||
|
@ -27,6 +27,10 @@ class PingStatus:
|
||||||
"""Run the ping status."""
|
"""Run the ping status."""
|
||||||
dashboard = DASHBOARD
|
dashboard = DASHBOARD
|
||||||
entries = dashboard.entries
|
entries = dashboard.entries
|
||||||
|
privileged = await _can_use_icmp_lib_with_privilege()
|
||||||
|
if privileged is None:
|
||||||
|
_LOGGER.warning("Cannot use icmplib because privileges are insufficient")
|
||||||
|
return
|
||||||
|
|
||||||
while not dashboard.stop_event.is_set():
|
while not dashboard.stop_event.is_set():
|
||||||
# Only ping if the dashboard is open
|
# Only ping if the dashboard is open
|
||||||
|
@ -36,15 +40,68 @@ class PingStatus:
|
||||||
to_ping: list[DashboardEntry] = [
|
to_ping: list[DashboardEntry] = [
|
||||||
entry for entry in current_entries if entry.address is not None
|
entry for entry in current_entries if entry.address is not None
|
||||||
]
|
]
|
||||||
for ping_group in chunked(to_ping, 16):
|
|
||||||
|
# Resolve DNS for all entries
|
||||||
|
entries_with_addresses: dict[DashboardEntry, list[str]] = {}
|
||||||
|
for ping_group in chunked(to_ping, GROUP_SIZE):
|
||||||
ping_group = cast(list[DashboardEntry], ping_group)
|
ping_group = cast(list[DashboardEntry], ping_group)
|
||||||
results = await asyncio.gather(
|
now_monotonic = time.monotonic()
|
||||||
*(_async_ping_host(entry.address) for entry in ping_group),
|
dns_results = await asyncio.gather(
|
||||||
|
*(
|
||||||
|
dashboard.dns_cache.async_resolve(entry.address, now_monotonic)
|
||||||
|
for entry in ping_group
|
||||||
|
),
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
for entry, result in zip(ping_group, results):
|
|
||||||
|
for entry, result in zip(ping_group, dns_results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
result = False
|
entries.async_set_state(entry, EntryState.UNKNOWN)
|
||||||
|
continue
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
raise result
|
||||||
|
entries_with_addresses[entry] = result
|
||||||
|
|
||||||
|
# Ping all entries with valid addresses
|
||||||
|
for ping_group in chunked(entries_with_addresses.items(), GROUP_SIZE):
|
||||||
|
entry_addresses = cast(tuple[DashboardEntry, list[str]], ping_group)
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*(
|
||||||
|
async_ping(addresses[0], privileged=privileged)
|
||||||
|
for _, addresses in entry_addresses
|
||||||
|
),
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for entry_addresses, result in zip(entry_addresses, results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
ping_result = False
|
||||||
elif isinstance(result, BaseException):
|
elif isinstance(result, BaseException):
|
||||||
raise result
|
raise result
|
||||||
entries.async_set_state(entry, bool_to_entry_state(result))
|
else:
|
||||||
|
host: Host = result
|
||||||
|
ping_result = host.is_alive
|
||||||
|
entry, _ = entry_addresses
|
||||||
|
entries.async_set_state(entry, bool_to_entry_state(ping_result))
|
||||||
|
|
||||||
|
|
||||||
|
async def _can_use_icmp_lib_with_privilege() -> None | bool:
|
||||||
|
"""Verify we can create a raw socket."""
|
||||||
|
try:
|
||||||
|
await async_ping("127.0.0.1", count=0, timeout=0, privileged=True)
|
||||||
|
except SocketPermissionError:
|
||||||
|
try:
|
||||||
|
await async_ping("127.0.0.1", count=0, timeout=0, privileged=False)
|
||||||
|
except SocketPermissionError:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Cannot use icmplib because privileges are insufficient to create the"
|
||||||
|
" socket"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
_LOGGER.debug("Using icmplib in privileged=False mode")
|
||||||
|
return False
|
||||||
|
|
||||||
|
_LOGGER.debug("Using icmplib in privileged=True mode")
|
||||||
|
return True
|
||||||
|
|
|
@ -9,6 +9,7 @@ import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
import secrets
|
import secrets
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
@ -302,16 +303,28 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
|
||||||
port = json_message["port"]
|
port = json_message["port"]
|
||||||
if (
|
if (
|
||||||
port == "OTA" # pylint: disable=too-many-boolean-expressions
|
port == "OTA" # pylint: disable=too-many-boolean-expressions
|
||||||
and (mdns := dashboard.mdns_status)
|
|
||||||
and (entry := entries.get(config_file))
|
and (entry := entries.get(config_file))
|
||||||
and entry.loaded_integrations
|
and entry.loaded_integrations
|
||||||
and "api" in entry.loaded_integrations
|
and "api" in entry.loaded_integrations
|
||||||
and (address := await mdns.async_resolve_host(entry.name))
|
|
||||||
):
|
):
|
||||||
# Use the IP address if available but only
|
if (mdns := dashboard.mdns_status) and (
|
||||||
# if the API is loaded and the device is online
|
address := await mdns.async_resolve_host(entry.name)
|
||||||
# since MQTT logging will not work otherwise
|
):
|
||||||
port = address
|
# Use the IP address if available but only
|
||||||
|
# if the API is loaded and the device is online
|
||||||
|
# since MQTT logging will not work otherwise
|
||||||
|
port = address
|
||||||
|
elif (
|
||||||
|
entry.address
|
||||||
|
and (
|
||||||
|
address_list := await dashboard.dns_cache.async_resolve(
|
||||||
|
entry.address, time.monotonic()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
and not isinstance(address_list, Exception)
|
||||||
|
):
|
||||||
|
# If mdns is not available, try to use the DNS cache
|
||||||
|
port = address_list[0]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
*DASHBOARD_COMMAND,
|
*DASHBOARD_COMMAND,
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
async_timeout==4.0.3; python_version <= "3.10"
|
||||||
voluptuous==0.14.1
|
voluptuous==0.14.1
|
||||||
PyYAML==6.0.1
|
PyYAML==6.0.1
|
||||||
paho-mqtt==1.6.1
|
paho-mqtt==1.6.1
|
||||||
colorama==0.4.6
|
colorama==0.4.6
|
||||||
|
icmplib==3.0.4
|
||||||
tornado==6.4
|
tornado==6.4
|
||||||
tzlocal==5.2 # from time
|
tzlocal==5.2 # from time
|
||||||
tzdata>=2021.1 # from time
|
tzdata>=2021.1 # from time
|
||||||
|
|
Loading…
Reference in a new issue