mirror of
https://github.com/esphome/esphome.git
synced 2024-11-10 09:17:46 +01:00
dashboard: convert ping thread to use asyncio (#5749)
This commit is contained in:
parent
642db6d92b
commit
20ea8bf06e
3 changed files with 73 additions and 88 deletions
|
@ -1,18 +1,13 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadedAsyncEvent:
|
class AsyncEvent:
|
||||||
"""This is a shim to allow the asyncio event to be used in a threaded context.
|
"""This is a shim around asyncio.Event."""
|
||||||
|
|
||||||
When more of the code is moved to asyncio, this can be removed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize the ThreadedAsyncEvent."""
|
"""Initialize the ThreadedAsyncEvent."""
|
||||||
self.event = threading.Event()
|
|
||||||
self.async_event: asyncio.Event | None = None
|
self.async_event: asyncio.Event | None = None
|
||||||
self.loop: asyncio.AbstractEventLoop | None = None
|
self.loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
|
||||||
|
@ -26,31 +21,11 @@ class ThreadedAsyncEvent:
|
||||||
def async_set(self) -> None:
|
def async_set(self) -> None:
|
||||||
"""Set the asyncio.Event instance."""
|
"""Set the asyncio.Event instance."""
|
||||||
self.async_event.set()
|
self.async_event.set()
|
||||||
self.event.set()
|
|
||||||
|
|
||||||
def set(self) -> None:
|
|
||||||
"""Set the event."""
|
|
||||||
self.loop.call_soon_threadsafe(self.async_event.set)
|
|
||||||
self.event.set()
|
|
||||||
|
|
||||||
def wait(self) -> None:
|
|
||||||
"""Wait for the event."""
|
|
||||||
self.event.wait()
|
|
||||||
|
|
||||||
async def async_wait(self) -> None:
|
async def async_wait(self) -> None:
|
||||||
"""Wait the event async."""
|
"""Wait the event async."""
|
||||||
await self.async_event.wait()
|
await self.async_event.wait()
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
"""Clear the event."""
|
|
||||||
self.loop.call_soon_threadsafe(self.async_event.clear)
|
|
||||||
self.event.clear()
|
|
||||||
|
|
||||||
def async_clear(self) -> None:
|
def async_clear(self) -> None:
|
||||||
"""Clear the event async."""
|
"""Clear the event async."""
|
||||||
self.async_event.clear()
|
self.async_event.clear()
|
||||||
self.event.clear()
|
|
||||||
|
|
||||||
def is_set(self) -> bool:
|
|
||||||
"""Return if the event is set."""
|
|
||||||
return self.event.is_set()
|
|
||||||
|
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
import collections
|
|
||||||
import datetime
|
import datetime
|
||||||
import functools
|
import functools
|
||||||
import gzip
|
import gzip
|
||||||
|
@ -11,14 +10,13 @@ import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import tornado
|
import tornado
|
||||||
import tornado.concurrent
|
import tornado.concurrent
|
||||||
|
@ -52,9 +50,9 @@ from esphome.zeroconf import (
|
||||||
DashboardImportDiscovery,
|
DashboardImportDiscovery,
|
||||||
DashboardStatus,
|
DashboardStatus,
|
||||||
)
|
)
|
||||||
from .async_adapter import ThreadedAsyncEvent
|
|
||||||
|
|
||||||
from .util import friendly_name_slugify, password_hash
|
from .async_adapter import AsyncEvent
|
||||||
|
from .util import chunked, friendly_name_slugify, password_hash
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -603,7 +601,7 @@ class ImportRequestHandler(BaseHandler):
|
||||||
encryption,
|
encryption,
|
||||||
)
|
)
|
||||||
# Make sure the device gets marked online right away
|
# Make sure the device gets marked online right away
|
||||||
PING_REQUEST.set()
|
PING_REQUEST.async_set()
|
||||||
except FileExistsError:
|
except FileExistsError:
|
||||||
self.set_status(500)
|
self.set_status(500)
|
||||||
self.write("File already exists")
|
self.write("File already exists")
|
||||||
|
@ -905,15 +903,6 @@ class MainRequestHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ping_func(filename, address):
|
|
||||||
if os.name == "nt":
|
|
||||||
command = ["ping", "-n", "1", address]
|
|
||||||
else:
|
|
||||||
command = ["ping", "-c", "1", address]
|
|
||||||
rc, _, _ = run_system_command(*command)
|
|
||||||
return filename, rc == 0
|
|
||||||
|
|
||||||
|
|
||||||
class PrometheusServiceDiscoveryHandler(BaseHandler):
|
class PrometheusServiceDiscoveryHandler(BaseHandler):
|
||||||
@authenticated
|
@authenticated
|
||||||
def get(self):
|
def get(self):
|
||||||
|
@ -1070,47 +1059,48 @@ class MDNSStatus:
|
||||||
self.aiozc = None
|
self.aiozc = None
|
||||||
|
|
||||||
|
|
||||||
class PingStatusThread(threading.Thread):
|
async def _async_ping_host(host: str) -> bool:
|
||||||
def run(self):
|
"""Ping a host."""
|
||||||
with multiprocessing.Pool(processes=8) as pool:
|
ping_command = ["ping", "-n" if os.name == "nt" else "-c", "1"]
|
||||||
while not STOP_EVENT.wait(2):
|
process = await asyncio.create_subprocess_exec(
|
||||||
# Only do pings if somebody has the dashboard open
|
*ping_command,
|
||||||
|
host,
|
||||||
def callback(ret):
|
stdin=asyncio.subprocess.DEVNULL,
|
||||||
PING_RESULT[ret[0]] = ret[1]
|
stdout=asyncio.subprocess.DEVNULL,
|
||||||
|
stderr=asyncio.subprocess.DEVNULL,
|
||||||
entries = _list_dashboard_entries()
|
|
||||||
queue = collections.deque()
|
|
||||||
for entry in entries:
|
|
||||||
if entry.address is None:
|
|
||||||
PING_RESULT[entry.filename] = None
|
|
||||||
continue
|
|
||||||
|
|
||||||
result = pool.apply_async(
|
|
||||||
_ping_func, (entry.filename, entry.address), callback=callback
|
|
||||||
)
|
)
|
||||||
queue.append(result)
|
await process.wait()
|
||||||
|
return process.returncode == 0
|
||||||
|
|
||||||
while queue:
|
|
||||||
item = queue[0]
|
|
||||||
if item.ready():
|
|
||||||
queue.popleft()
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
class PingStatus:
|
||||||
item.get(0.1)
|
def __init__(self) -> None:
|
||||||
except OSError:
|
"""Initialize the PingStatus class."""
|
||||||
# ping not installed
|
super().__init__()
|
||||||
pass
|
self._loop = asyncio.get_running_loop()
|
||||||
except multiprocessing.TimeoutError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if STOP_EVENT.is_set():
|
async def async_run(self) -> None:
|
||||||
pool.terminate()
|
"""Run the ping status."""
|
||||||
return
|
while not STOP_EVENT.is_set():
|
||||||
|
# Only ping if the dashboard is open
|
||||||
PING_REQUEST.wait()
|
await PING_REQUEST.async_wait()
|
||||||
PING_REQUEST.clear()
|
PING_REQUEST.async_clear()
|
||||||
|
entries = await self._loop.run_in_executor(None, _list_dashboard_entries)
|
||||||
|
to_ping: list[DashboardEntry] = [
|
||||||
|
entry for entry in entries if entry.address is not None
|
||||||
|
]
|
||||||
|
for ping_group in chunked(to_ping, 16):
|
||||||
|
ping_group = cast(list[DashboardEntry], ping_group)
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*(_async_ping_host(entry.address) for entry in ping_group),
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
for entry, result in zip(ping_group, results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
result = False
|
||||||
|
elif isinstance(result, BaseException):
|
||||||
|
raise result
|
||||||
|
PING_RESULT[entry.filename] = result
|
||||||
|
|
||||||
|
|
||||||
class MqttStatusThread(threading.Thread):
|
class MqttStatusThread(threading.Thread):
|
||||||
|
@ -1171,7 +1161,7 @@ class MqttStatusThread(threading.Thread):
|
||||||
class PingRequestHandler(BaseHandler):
|
class PingRequestHandler(BaseHandler):
|
||||||
@authenticated
|
@authenticated
|
||||||
def get(self):
|
def get(self):
|
||||||
PING_REQUEST.set()
|
PING_REQUEST.async_set()
|
||||||
if settings.status_use_mqtt:
|
if settings.status_use_mqtt:
|
||||||
MQTT_PING_REQUEST.set()
|
MQTT_PING_REQUEST.set()
|
||||||
self.set_header("content-type", "application/json")
|
self.set_header("content-type", "application/json")
|
||||||
|
@ -1261,7 +1251,7 @@ class MDNSContainer:
|
||||||
PING_RESULT: dict = {}
|
PING_RESULT: dict = {}
|
||||||
IMPORT_RESULT = {}
|
IMPORT_RESULT = {}
|
||||||
STOP_EVENT = threading.Event()
|
STOP_EVENT = threading.Event()
|
||||||
PING_REQUEST = ThreadedAsyncEvent()
|
PING_REQUEST = AsyncEvent()
|
||||||
MQTT_PING_REQUEST = threading.Event()
|
MQTT_PING_REQUEST = threading.Event()
|
||||||
MDNS_CONTAINER = MDNSContainer()
|
MDNS_CONTAINER = MDNSContainer()
|
||||||
|
|
||||||
|
@ -1561,10 +1551,10 @@ async def async_start_web_server(args):
|
||||||
webbrowser.open(f"http://{args.address}:{args.port}")
|
webbrowser.open(f"http://{args.address}:{args.port}")
|
||||||
|
|
||||||
mdns_task: asyncio.Task | None = None
|
mdns_task: asyncio.Task | None = None
|
||||||
ping_status_thread: PingStatusThread | None = None
|
ping_status_task: asyncio.Task | None = None
|
||||||
if settings.status_use_ping:
|
if settings.status_use_ping:
|
||||||
ping_status_thread = PingStatusThread()
|
ping_status = PingStatus()
|
||||||
ping_status_thread.start()
|
ping_status_task = asyncio.create_task(ping_status.async_run())
|
||||||
else:
|
else:
|
||||||
mdns_status = MDNSStatus()
|
mdns_status = MDNSStatus()
|
||||||
await mdns_status.async_refresh_hosts()
|
await mdns_status.async_refresh_hosts()
|
||||||
|
@ -1581,9 +1571,9 @@ async def async_start_web_server(args):
|
||||||
finally:
|
finally:
|
||||||
_LOGGER.info("Shutting down...")
|
_LOGGER.info("Shutting down...")
|
||||||
STOP_EVENT.set()
|
STOP_EVENT.set()
|
||||||
PING_REQUEST.set()
|
PING_REQUEST.async_set()
|
||||||
if ping_status_thread:
|
if ping_status_task:
|
||||||
ping_status_thread.join()
|
ping_status_task.cancel()
|
||||||
MDNS_CONTAINER.set_mdns(None)
|
MDNS_CONTAINER.set_mdns(None)
|
||||||
if mdns_task:
|
if mdns_task:
|
||||||
mdns_task.cancel()
|
mdns_task.cancel()
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from functools import partial
|
||||||
|
from itertools import islice
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from esphome.const import ALLOWED_NAME_CHARS
|
from esphome.const import ALLOWED_NAME_CHARS
|
||||||
|
|
||||||
|
@ -30,3 +34,19 @@ def friendly_name_slugify(value):
|
||||||
.strip("-")
|
.strip("-")
|
||||||
)
|
)
|
||||||
return "".join(c for c in value if c in ALLOWED_NAME_CHARS)
|
return "".join(c for c in value if c in ALLOWED_NAME_CHARS)
|
||||||
|
|
||||||
|
|
||||||
|
def take(take_num: int, iterable: Iterable) -> list[Any]:
|
||||||
|
"""Return first n items of the iterable as a list.
|
||||||
|
|
||||||
|
From itertools recipes
|
||||||
|
"""
|
||||||
|
return list(islice(iterable, take_num))
|
||||||
|
|
||||||
|
|
||||||
|
def chunked(iterable: Iterable, chunked_num: int) -> Iterable[Any]:
|
||||||
|
"""Break *iterable* into lists of length *n*.
|
||||||
|
|
||||||
|
From more-itertools
|
||||||
|
"""
|
||||||
|
return iter(partial(take, chunked_num, iter(iterable)), [])
|
||||||
|
|
Loading…
Reference in a new issue