From 20ea8bf06e4414a038df087c210a47e94cefab23 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 14 Nov 2023 22:55:33 -0600 Subject: [PATCH] dashboard: convert ping thread to use asyncio (#5749) --- esphome/dashboard/async_adapter.py | 29 +------- esphome/dashboard/dashboard.py | 112 +++++++++++++---------------- esphome/dashboard/util.py | 20 ++++++ 3 files changed, 73 insertions(+), 88 deletions(-) diff --git a/esphome/dashboard/async_adapter.py b/esphome/dashboard/async_adapter.py index d6f4f6e1ff..44d2f42ce0 100644 --- a/esphome/dashboard/async_adapter.py +++ b/esphome/dashboard/async_adapter.py @@ -1,18 +1,13 @@ from __future__ import annotations import asyncio -import threading -class ThreadedAsyncEvent: - """This is a shim to allow the asyncio event to be used in a threaded context. - - When more of the code is moved to asyncio, this can be removed. - """ +class AsyncEvent: + """This is a shim around asyncio.Event.""" def __init__(self) -> None: """Initialize the ThreadedAsyncEvent.""" - self.event = threading.Event() self.async_event: asyncio.Event | None = None self.loop: asyncio.AbstractEventLoop | None = None @@ -26,31 +21,11 @@ class ThreadedAsyncEvent: def async_set(self) -> None: """Set the asyncio.Event instance.""" self.async_event.set() - self.event.set() - - def set(self) -> None: - """Set the event.""" - self.loop.call_soon_threadsafe(self.async_event.set) - self.event.set() - - def wait(self) -> None: - """Wait for the event.""" - self.event.wait() async def async_wait(self) -> None: """Wait the event async.""" await self.async_event.wait() - def clear(self) -> None: - """Clear the event.""" - self.loop.call_soon_threadsafe(self.async_event.clear) - self.event.clear() - def async_clear(self) -> None: """Clear the event async.""" self.async_event.clear() - self.event.clear() - - def is_set(self) -> bool: - """Return if the event is set.""" - return self.event.is_set() diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index d7d11d8693..950386d969 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio import base64 import binascii -import collections import datetime import functools import gzip @@ -11,14 +10,13 @@ import hashlib import hmac import json import logging -import multiprocessing import os import secrets import shutil import subprocess import threading from pathlib import Path -from typing import Any +from typing import Any, cast import tornado import tornado.concurrent @@ -52,9 +50,9 @@ from esphome.zeroconf import ( DashboardImportDiscovery, DashboardStatus, ) -from .async_adapter import ThreadedAsyncEvent -from .util import friendly_name_slugify, password_hash +from .async_adapter import AsyncEvent +from .util import chunked, friendly_name_slugify, password_hash _LOGGER = logging.getLogger(__name__) @@ -603,7 +601,7 @@ class ImportRequestHandler(BaseHandler): encryption, ) # Make sure the device gets marked online right away - PING_REQUEST.set() + PING_REQUEST.async_set() except FileExistsError: self.set_status(500) self.write("File already exists") @@ -905,15 +903,6 @@ class MainRequestHandler(BaseHandler): ) -def _ping_func(filename, address): - if os.name == "nt": - command = ["ping", "-n", "1", address] - else: - command = ["ping", "-c", "1", address] - rc, _, _ = run_system_command(*command) - return filename, rc == 0 - - class PrometheusServiceDiscoveryHandler(BaseHandler): @authenticated def get(self): @@ -1070,47 +1059,48 @@ class MDNSStatus: self.aiozc = None -class PingStatusThread(threading.Thread): - def run(self): - with multiprocessing.Pool(processes=8) as pool: - while not STOP_EVENT.wait(2): - # Only do pings if somebody has the dashboard open +async def _async_ping_host(host: str) -> bool: + """Ping a host.""" + ping_command = ["ping", "-n" if os.name == "nt" else "-c", "1"] + process = await asyncio.create_subprocess_exec( + *ping_command, + host, + stdin=asyncio.subprocess.DEVNULL, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await process.wait() + return process.returncode == 0 - def callback(ret): - PING_RESULT[ret[0]] = ret[1] - entries = _list_dashboard_entries() - queue = collections.deque() - for entry in entries: - if entry.address is None: - PING_RESULT[entry.filename] = None - continue +class PingStatus: + def __init__(self) -> None: + """Initialize the PingStatus class.""" + super().__init__() + self._loop = asyncio.get_running_loop() - result = pool.apply_async( - _ping_func, (entry.filename, entry.address), callback=callback - ) - queue.append(result) - - while queue: - item = queue[0] - if item.ready(): - queue.popleft() - continue - - try: - item.get(0.1) - except OSError: - # ping not installed - pass - except multiprocessing.TimeoutError: - pass - - if STOP_EVENT.is_set(): - pool.terminate() - return - - PING_REQUEST.wait() - PING_REQUEST.clear() + async def async_run(self) -> None: + """Run the ping status.""" + while not STOP_EVENT.is_set(): + # Only ping if the dashboard is open + await PING_REQUEST.async_wait() + PING_REQUEST.async_clear() + entries = await self._loop.run_in_executor(None, _list_dashboard_entries) + to_ping: list[DashboardEntry] = [ + entry for entry in entries if entry.address is not None + ] + for ping_group in chunked(to_ping, 16): + ping_group = cast(list[DashboardEntry], ping_group) + results = await asyncio.gather( + *(_async_ping_host(entry.address) for entry in ping_group), + return_exceptions=True, + ) + for entry, result in zip(ping_group, results): + if isinstance(result, Exception): + result = False + elif isinstance(result, BaseException): + raise result + PING_RESULT[entry.filename] = result class MqttStatusThread(threading.Thread): @@ -1171,7 +1161,7 @@ class MqttStatusThread(threading.Thread): class PingRequestHandler(BaseHandler): @authenticated def get(self): - PING_REQUEST.set() + PING_REQUEST.async_set() if settings.status_use_mqtt: MQTT_PING_REQUEST.set() self.set_header("content-type", "application/json") @@ -1261,7 +1251,7 @@ class MDNSContainer: PING_RESULT: dict = {} IMPORT_RESULT = {} STOP_EVENT = threading.Event() -PING_REQUEST = ThreadedAsyncEvent() +PING_REQUEST = AsyncEvent() MQTT_PING_REQUEST = threading.Event() MDNS_CONTAINER = MDNSContainer() @@ -1561,10 +1551,10 @@ async def async_start_web_server(args): webbrowser.open(f"http://{args.address}:{args.port}") mdns_task: asyncio.Task | None = None - ping_status_thread: PingStatusThread | None = None + ping_status_task: asyncio.Task | None = None if settings.status_use_ping: - ping_status_thread = PingStatusThread() - ping_status_thread.start() + ping_status = PingStatus() + ping_status_task = asyncio.create_task(ping_status.async_run()) else: mdns_status = MDNSStatus() await mdns_status.async_refresh_hosts() @@ -1581,9 +1571,9 @@ async def async_start_web_server(args): finally: _LOGGER.info("Shutting down...") STOP_EVENT.set() - PING_REQUEST.set() - if ping_status_thread: - ping_status_thread.join() + PING_REQUEST.async_set() + if ping_status_task: + ping_status_task.cancel() MDNS_CONTAINER.set_mdns(None) if mdns_task: mdns_task.cancel() diff --git a/esphome/dashboard/util.py b/esphome/dashboard/util.py index a2ad530b74..7b6572b989 100644 --- a/esphome/dashboard/util.py +++ b/esphome/dashboard/util.py @@ -1,5 +1,9 @@ import hashlib import unicodedata +from collections.abc import Iterable +from functools import partial +from itertools import islice +from typing import Any from esphome.const import ALLOWED_NAME_CHARS @@ -30,3 +34,19 @@ def friendly_name_slugify(value): .strip("-") ) return "".join(c for c in value if c in ALLOWED_NAME_CHARS) + + +def take(take_num: int, iterable: Iterable) -> list[Any]: + """Return first n items of the iterable as a list. + + From itertools recipes + """ + return list(islice(iterable, take_num)) + + +def chunked(iterable: Iterable, chunked_num: int) -> Iterable[Any]: + """Break *iterable* into lists of length *n*. + + From more-itertools + """ + return iter(partial(take, chunked_num, iter(iterable)), [])