dashboard: convert ping thread to use asyncio (#5749)

This commit is contained in:
J. Nick Koston 2023-11-14 22:55:33 -06:00 committed by GitHub
parent 642db6d92b
commit 20ea8bf06e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 88 deletions

View file

@ -1,18 +1,13 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import threading
class ThreadedAsyncEvent: class AsyncEvent:
"""This is a shim to allow the asyncio event to be used in a threaded context. """This is a shim around asyncio.Event."""
When more of the code is moved to asyncio, this can be removed.
"""
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the ThreadedAsyncEvent.""" """Initialize the ThreadedAsyncEvent."""
self.event = threading.Event()
self.async_event: asyncio.Event | None = None self.async_event: asyncio.Event | None = None
self.loop: asyncio.AbstractEventLoop | None = None self.loop: asyncio.AbstractEventLoop | None = None
@ -26,31 +21,11 @@ class ThreadedAsyncEvent:
def async_set(self) -> None: def async_set(self) -> None:
"""Set the asyncio.Event instance.""" """Set the asyncio.Event instance."""
self.async_event.set() self.async_event.set()
self.event.set()
def set(self) -> None:
"""Set the event."""
self.loop.call_soon_threadsafe(self.async_event.set)
self.event.set()
def wait(self) -> None:
"""Wait for the event."""
self.event.wait()
async def async_wait(self) -> None: async def async_wait(self) -> None:
"""Wait the event async.""" """Wait the event async."""
await self.async_event.wait() await self.async_event.wait()
def clear(self) -> None:
"""Clear the event."""
self.loop.call_soon_threadsafe(self.async_event.clear)
self.event.clear()
def async_clear(self) -> None: def async_clear(self) -> None:
"""Clear the event async.""" """Clear the event async."""
self.async_event.clear() self.async_event.clear()
self.event.clear()
def is_set(self) -> bool:
"""Return if the event is set."""
return self.event.is_set()

View file

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio import asyncio
import base64 import base64
import binascii import binascii
import collections
import datetime import datetime
import functools import functools
import gzip import gzip
@ -11,14 +10,13 @@ import hashlib
import hmac import hmac
import json import json
import logging import logging
import multiprocessing
import os import os
import secrets import secrets
import shutil import shutil
import subprocess import subprocess
import threading import threading
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, cast
import tornado import tornado
import tornado.concurrent import tornado.concurrent
@ -52,9 +50,9 @@ from esphome.zeroconf import (
DashboardImportDiscovery, DashboardImportDiscovery,
DashboardStatus, DashboardStatus,
) )
from .async_adapter import ThreadedAsyncEvent
from .util import friendly_name_slugify, password_hash from .async_adapter import AsyncEvent
from .util import chunked, friendly_name_slugify, password_hash
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -603,7 +601,7 @@ class ImportRequestHandler(BaseHandler):
encryption, encryption,
) )
# Make sure the device gets marked online right away # Make sure the device gets marked online right away
PING_REQUEST.set() PING_REQUEST.async_set()
except FileExistsError: except FileExistsError:
self.set_status(500) self.set_status(500)
self.write("File already exists") self.write("File already exists")
@ -905,15 +903,6 @@ class MainRequestHandler(BaseHandler):
) )
def _ping_func(filename, address):
if os.name == "nt":
command = ["ping", "-n", "1", address]
else:
command = ["ping", "-c", "1", address]
rc, _, _ = run_system_command(*command)
return filename, rc == 0
class PrometheusServiceDiscoveryHandler(BaseHandler): class PrometheusServiceDiscoveryHandler(BaseHandler):
@authenticated @authenticated
def get(self): def get(self):
@ -1070,47 +1059,48 @@ class MDNSStatus:
self.aiozc = None self.aiozc = None
class PingStatusThread(threading.Thread): async def _async_ping_host(host: str) -> bool:
def run(self): """Ping a host."""
with multiprocessing.Pool(processes=8) as pool: ping_command = ["ping", "-n" if os.name == "nt" else "-c", "1"]
while not STOP_EVENT.wait(2): process = await asyncio.create_subprocess_exec(
# Only do pings if somebody has the dashboard open *ping_command,
host,
stdin=asyncio.subprocess.DEVNULL,
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
await process.wait()
return process.returncode == 0
def callback(ret):
PING_RESULT[ret[0]] = ret[1]
entries = _list_dashboard_entries() class PingStatus:
queue = collections.deque() def __init__(self) -> None:
for entry in entries: """Initialize the PingStatus class."""
if entry.address is None: super().__init__()
PING_RESULT[entry.filename] = None self._loop = asyncio.get_running_loop()
continue
result = pool.apply_async( async def async_run(self) -> None:
_ping_func, (entry.filename, entry.address), callback=callback """Run the ping status."""
) while not STOP_EVENT.is_set():
queue.append(result) # Only ping if the dashboard is open
await PING_REQUEST.async_wait()
while queue: PING_REQUEST.async_clear()
item = queue[0] entries = await self._loop.run_in_executor(None, _list_dashboard_entries)
if item.ready(): to_ping: list[DashboardEntry] = [
queue.popleft() entry for entry in entries if entry.address is not None
continue ]
for ping_group in chunked(to_ping, 16):
try: ping_group = cast(list[DashboardEntry], ping_group)
item.get(0.1) results = await asyncio.gather(
except OSError: *(_async_ping_host(entry.address) for entry in ping_group),
# ping not installed return_exceptions=True,
pass )
except multiprocessing.TimeoutError: for entry, result in zip(ping_group, results):
pass if isinstance(result, Exception):
result = False
if STOP_EVENT.is_set(): elif isinstance(result, BaseException):
pool.terminate() raise result
return PING_RESULT[entry.filename] = result
PING_REQUEST.wait()
PING_REQUEST.clear()
class MqttStatusThread(threading.Thread): class MqttStatusThread(threading.Thread):
@ -1171,7 +1161,7 @@ class MqttStatusThread(threading.Thread):
class PingRequestHandler(BaseHandler): class PingRequestHandler(BaseHandler):
@authenticated @authenticated
def get(self): def get(self):
PING_REQUEST.set() PING_REQUEST.async_set()
if settings.status_use_mqtt: if settings.status_use_mqtt:
MQTT_PING_REQUEST.set() MQTT_PING_REQUEST.set()
self.set_header("content-type", "application/json") self.set_header("content-type", "application/json")
@ -1261,7 +1251,7 @@ class MDNSContainer:
PING_RESULT: dict = {} PING_RESULT: dict = {}
IMPORT_RESULT = {} IMPORT_RESULT = {}
STOP_EVENT = threading.Event() STOP_EVENT = threading.Event()
PING_REQUEST = ThreadedAsyncEvent() PING_REQUEST = AsyncEvent()
MQTT_PING_REQUEST = threading.Event() MQTT_PING_REQUEST = threading.Event()
MDNS_CONTAINER = MDNSContainer() MDNS_CONTAINER = MDNSContainer()
@ -1561,10 +1551,10 @@ async def async_start_web_server(args):
webbrowser.open(f"http://{args.address}:{args.port}") webbrowser.open(f"http://{args.address}:{args.port}")
mdns_task: asyncio.Task | None = None mdns_task: asyncio.Task | None = None
ping_status_thread: PingStatusThread | None = None ping_status_task: asyncio.Task | None = None
if settings.status_use_ping: if settings.status_use_ping:
ping_status_thread = PingStatusThread() ping_status = PingStatus()
ping_status_thread.start() ping_status_task = asyncio.create_task(ping_status.async_run())
else: else:
mdns_status = MDNSStatus() mdns_status = MDNSStatus()
await mdns_status.async_refresh_hosts() await mdns_status.async_refresh_hosts()
@ -1581,9 +1571,9 @@ async def async_start_web_server(args):
finally: finally:
_LOGGER.info("Shutting down...") _LOGGER.info("Shutting down...")
STOP_EVENT.set() STOP_EVENT.set()
PING_REQUEST.set() PING_REQUEST.async_set()
if ping_status_thread: if ping_status_task:
ping_status_thread.join() ping_status_task.cancel()
MDNS_CONTAINER.set_mdns(None) MDNS_CONTAINER.set_mdns(None)
if mdns_task: if mdns_task:
mdns_task.cancel() mdns_task.cancel()

View file

@ -1,5 +1,9 @@
import hashlib import hashlib
import unicodedata import unicodedata
from collections.abc import Iterable
from functools import partial
from itertools import islice
from typing import Any
from esphome.const import ALLOWED_NAME_CHARS from esphome.const import ALLOWED_NAME_CHARS
@ -30,3 +34,19 @@ def friendly_name_slugify(value):
.strip("-") .strip("-")
) )
return "".join(c for c in value if c in ALLOWED_NAME_CHARS) return "".join(c for c in value if c in ALLOWED_NAME_CHARS)
def take(take_num: int, iterable: Iterable) -> list[Any]:
"""Return first n items of the iterable as a list.
From itertools recipes
"""
return list(islice(iterable, take_num))
def chunked(iterable: Iterable, chunked_num: int) -> Iterable[Any]:
"""Break *iterable* into lists of length *n*.
From more-itertools
"""
return iter(partial(take, chunked_num, iter(iterable)), [])