mirror of
https://github.com/esphome/esphome.git
synced 2024-11-28 17:54:13 +01:00
dashboard: convert ping thread to use asyncio (#5749)
This commit is contained in:
parent
642db6d92b
commit
20ea8bf06e
3 changed files with 73 additions and 88 deletions
|
@ -1,18 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
|
||||
class ThreadedAsyncEvent:
|
||||
"""This is a shim to allow the asyncio event to be used in a threaded context.
|
||||
|
||||
When more of the code is moved to asyncio, this can be removed.
|
||||
"""
|
||||
class AsyncEvent:
|
||||
"""This is a shim around asyncio.Event."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the ThreadedAsyncEvent."""
|
||||
self.event = threading.Event()
|
||||
self.async_event: asyncio.Event | None = None
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
|
@ -26,31 +21,11 @@ class ThreadedAsyncEvent:
|
|||
def async_set(self) -> None:
|
||||
"""Set the asyncio.Event instance."""
|
||||
self.async_event.set()
|
||||
self.event.set()
|
||||
|
||||
def set(self) -> None:
|
||||
"""Set the event."""
|
||||
self.loop.call_soon_threadsafe(self.async_event.set)
|
||||
self.event.set()
|
||||
|
||||
def wait(self) -> None:
|
||||
"""Wait for the event."""
|
||||
self.event.wait()
|
||||
|
||||
async def async_wait(self) -> None:
|
||||
"""Wait the event async."""
|
||||
await self.async_event.wait()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the event."""
|
||||
self.loop.call_soon_threadsafe(self.async_event.clear)
|
||||
self.event.clear()
|
||||
|
||||
def async_clear(self) -> None:
|
||||
"""Clear the event async."""
|
||||
self.async_event.clear()
|
||||
self.event.clear()
|
||||
|
||||
def is_set(self) -> bool:
|
||||
"""Return if the event is set."""
|
||||
return self.event.is_set()
|
||||
|
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import collections
|
||||
import datetime
|
||||
import functools
|
||||
import gzip
|
||||
|
@ -11,14 +10,13 @@ import hashlib
|
|||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import tornado
|
||||
import tornado.concurrent
|
||||
|
@ -52,9 +50,9 @@ from esphome.zeroconf import (
|
|||
DashboardImportDiscovery,
|
||||
DashboardStatus,
|
||||
)
|
||||
from .async_adapter import ThreadedAsyncEvent
|
||||
|
||||
from .util import friendly_name_slugify, password_hash
|
||||
from .async_adapter import AsyncEvent
|
||||
from .util import chunked, friendly_name_slugify, password_hash
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -603,7 +601,7 @@ class ImportRequestHandler(BaseHandler):
|
|||
encryption,
|
||||
)
|
||||
# Make sure the device gets marked online right away
|
||||
PING_REQUEST.set()
|
||||
PING_REQUEST.async_set()
|
||||
except FileExistsError:
|
||||
self.set_status(500)
|
||||
self.write("File already exists")
|
||||
|
@ -905,15 +903,6 @@ class MainRequestHandler(BaseHandler):
|
|||
)
|
||||
|
||||
|
||||
def _ping_func(filename, address):
|
||||
if os.name == "nt":
|
||||
command = ["ping", "-n", "1", address]
|
||||
else:
|
||||
command = ["ping", "-c", "1", address]
|
||||
rc, _, _ = run_system_command(*command)
|
||||
return filename, rc == 0
|
||||
|
||||
|
||||
class PrometheusServiceDiscoveryHandler(BaseHandler):
|
||||
@authenticated
|
||||
def get(self):
|
||||
|
@ -1070,47 +1059,48 @@ class MDNSStatus:
|
|||
self.aiozc = None
|
||||
|
||||
|
||||
class PingStatusThread(threading.Thread):
|
||||
def run(self):
|
||||
with multiprocessing.Pool(processes=8) as pool:
|
||||
while not STOP_EVENT.wait(2):
|
||||
# Only do pings if somebody has the dashboard open
|
||||
|
||||
def callback(ret):
|
||||
PING_RESULT[ret[0]] = ret[1]
|
||||
|
||||
entries = _list_dashboard_entries()
|
||||
queue = collections.deque()
|
||||
for entry in entries:
|
||||
if entry.address is None:
|
||||
PING_RESULT[entry.filename] = None
|
||||
continue
|
||||
|
||||
result = pool.apply_async(
|
||||
_ping_func, (entry.filename, entry.address), callback=callback
|
||||
async def _async_ping_host(host: str) -> bool:
|
||||
"""Ping a host."""
|
||||
ping_command = ["ping", "-n" if os.name == "nt" else "-c", "1"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*ping_command,
|
||||
host,
|
||||
stdin=asyncio.subprocess.DEVNULL,
|
||||
stdout=asyncio.subprocess.DEVNULL,
|
||||
stderr=asyncio.subprocess.DEVNULL,
|
||||
)
|
||||
queue.append(result)
|
||||
await process.wait()
|
||||
return process.returncode == 0
|
||||
|
||||
while queue:
|
||||
item = queue[0]
|
||||
if item.ready():
|
||||
queue.popleft()
|
||||
continue
|
||||
|
||||
try:
|
||||
item.get(0.1)
|
||||
except OSError:
|
||||
# ping not installed
|
||||
pass
|
||||
except multiprocessing.TimeoutError:
|
||||
pass
|
||||
class PingStatus:
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the PingStatus class."""
|
||||
super().__init__()
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
if STOP_EVENT.is_set():
|
||||
pool.terminate()
|
||||
return
|
||||
|
||||
PING_REQUEST.wait()
|
||||
PING_REQUEST.clear()
|
||||
async def async_run(self) -> None:
|
||||
"""Run the ping status."""
|
||||
while not STOP_EVENT.is_set():
|
||||
# Only ping if the dashboard is open
|
||||
await PING_REQUEST.async_wait()
|
||||
PING_REQUEST.async_clear()
|
||||
entries = await self._loop.run_in_executor(None, _list_dashboard_entries)
|
||||
to_ping: list[DashboardEntry] = [
|
||||
entry for entry in entries if entry.address is not None
|
||||
]
|
||||
for ping_group in chunked(to_ping, 16):
|
||||
ping_group = cast(list[DashboardEntry], ping_group)
|
||||
results = await asyncio.gather(
|
||||
*(_async_ping_host(entry.address) for entry in ping_group),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for entry, result in zip(ping_group, results):
|
||||
if isinstance(result, Exception):
|
||||
result = False
|
||||
elif isinstance(result, BaseException):
|
||||
raise result
|
||||
PING_RESULT[entry.filename] = result
|
||||
|
||||
|
||||
class MqttStatusThread(threading.Thread):
|
||||
|
@ -1171,7 +1161,7 @@ class MqttStatusThread(threading.Thread):
|
|||
class PingRequestHandler(BaseHandler):
|
||||
@authenticated
|
||||
def get(self):
|
||||
PING_REQUEST.set()
|
||||
PING_REQUEST.async_set()
|
||||
if settings.status_use_mqtt:
|
||||
MQTT_PING_REQUEST.set()
|
||||
self.set_header("content-type", "application/json")
|
||||
|
@ -1261,7 +1251,7 @@ class MDNSContainer:
|
|||
PING_RESULT: dict = {}
|
||||
IMPORT_RESULT = {}
|
||||
STOP_EVENT = threading.Event()
|
||||
PING_REQUEST = ThreadedAsyncEvent()
|
||||
PING_REQUEST = AsyncEvent()
|
||||
MQTT_PING_REQUEST = threading.Event()
|
||||
MDNS_CONTAINER = MDNSContainer()
|
||||
|
||||
|
@ -1561,10 +1551,10 @@ async def async_start_web_server(args):
|
|||
webbrowser.open(f"http://{args.address}:{args.port}")
|
||||
|
||||
mdns_task: asyncio.Task | None = None
|
||||
ping_status_thread: PingStatusThread | None = None
|
||||
ping_status_task: asyncio.Task | None = None
|
||||
if settings.status_use_ping:
|
||||
ping_status_thread = PingStatusThread()
|
||||
ping_status_thread.start()
|
||||
ping_status = PingStatus()
|
||||
ping_status_task = asyncio.create_task(ping_status.async_run())
|
||||
else:
|
||||
mdns_status = MDNSStatus()
|
||||
await mdns_status.async_refresh_hosts()
|
||||
|
@ -1581,9 +1571,9 @@ async def async_start_web_server(args):
|
|||
finally:
|
||||
_LOGGER.info("Shutting down...")
|
||||
STOP_EVENT.set()
|
||||
PING_REQUEST.set()
|
||||
if ping_status_thread:
|
||||
ping_status_thread.join()
|
||||
PING_REQUEST.async_set()
|
||||
if ping_status_task:
|
||||
ping_status_task.cancel()
|
||||
MDNS_CONTAINER.set_mdns(None)
|
||||
if mdns_task:
|
||||
mdns_task.cancel()
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
import hashlib
|
||||
import unicodedata
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
from esphome.const import ALLOWED_NAME_CHARS
|
||||
|
||||
|
@ -30,3 +34,19 @@ def friendly_name_slugify(value):
|
|||
.strip("-")
|
||||
)
|
||||
return "".join(c for c in value if c in ALLOWED_NAME_CHARS)
|
||||
|
||||
|
||||
def take(take_num: int, iterable: Iterable) -> list[Any]:
|
||||
"""Return first n items of the iterable as a list.
|
||||
|
||||
From itertools recipes
|
||||
"""
|
||||
return list(islice(iterable, take_num))
|
||||
|
||||
|
||||
def chunked(iterable: Iterable, chunked_num: int) -> Iterable[Any]:
|
||||
"""Break *iterable* into lists of length *n*.
|
||||
|
||||
From more-itertools
|
||||
"""
|
||||
return iter(partial(take, chunked_num, iter(iterable)), [])
|
||||
|
|
Loading…
Reference in a new issue