Use mdns or freshen cache when device connection is OTA

Since we already have a service browser running, we likely
already know the IP of the deivce we want to connect to so
we can replace OTA with the address to avoid the esphome
app having to look it up again
This commit is contained in:
J. Nick Koston 2023-11-09 23:13:27 -06:00
parent 0f19450ab4
commit 4e21532578
No known key found for this signature in database
2 changed files with 87 additions and 45 deletions

View file

@ -32,7 +32,7 @@ import tornado.web
import tornado.websocket
import yaml
from tornado.log import access_log
from typing import Any
from esphome import const, platformio_api, util, yaml_util
from esphome.core import CORE
from esphome.helpers import get_bool_env, mkdir_p, run_system_command
@ -398,19 +398,40 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
raise NotImplementedError
class EsphomeLogsHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
DASHBOARD_COMMAND = ["esphome", "--dashboard"]
class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
"""Base class for commands that require a port."""
def run_command(self, args: list[str], json_message: dict[str, Any]) -> list[str]:
"""Build the command to run."""
configuration = json_message["configuration"]
config_file = settings.rel_path(configuration)
port = json_message["port"]
if (
port == "OTA"
and (mdns := MDNS_CONTAINER.get_mdns())
and (host_name := mdns.filename_to_host_name_thread_safe(configuration))
and (address := mdns.resolve_host_thread_safe(host_name))
):
port = address
return [
"esphome",
"--dashboard",
"logs",
*DASHBOARD_COMMAND,
*args,
config_file,
"--device",
json_message["port"],
port,
]
class EsphomeLogsHandler(EsphomePortCommandWebSocket):
def build_command(self, json_message: dict[str, Any]) -> list[str]:
"""Build the command to run."""
return self.run_command(["logs"], json_message)
class EsphomeRenameHandler(EsphomeCommandWebSocket):
old_name: str
@ -418,8 +439,7 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket):
config_file = settings.rel_path(json_message["configuration"])
self.old_name = json_message["configuration"]
return [
"esphome",
"--dashboard",
*DASHBOARD_COMMAND,
"rename",
config_file,
json_message["newName"],
@ -435,36 +455,22 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket):
PING_RESULT.pop(self.old_name, None)
class EsphomeUploadHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
return [
"esphome",
"--dashboard",
"upload",
config_file,
"--device",
json_message["port"],
]
class EsphomeUploadHandler(EsphomePortCommandWebSocket):
def build_command(self, json_message: dict[str, Any]) -> list[str]:
"""Build the command to run."""
return self.run_command(["upload"], json_message)
class EsphomeRunHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
return [
"esphome",
"--dashboard",
"run",
config_file,
"--device",
json_message["port"],
]
class EsphomeRunHandler(EsphomePortCommandWebSocket):
def build_command(self, json_message: dict[str, Any]) -> list[str]:
"""Build the command to run."""
return self.run_command(["run"], json_message)
class EsphomeCompileHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
command = ["esphome", "--dashboard", "compile"]
command = [*DASHBOARD_COMMAND, "compile"]
if json_message.get("only_generate", False):
command.append("--only-generate")
command.append(config_file)
@ -474,7 +480,7 @@ class EsphomeCompileHandler(EsphomeCommandWebSocket):
class EsphomeValidateHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
command = ["esphome", "--dashboard", "config", config_file]
command = [*DASHBOARD_COMMAND, "config", config_file]
if not settings.streamer_mode:
command.append("--show-secrets")
return command
@ -483,28 +489,28 @@ class EsphomeValidateHandler(EsphomeCommandWebSocket):
class EsphomeCleanMqttHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
return ["esphome", "--dashboard", "clean-mqtt", config_file]
return [*DASHBOARD_COMMAND, "clean-mqtt", config_file]
class EsphomeCleanHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
return ["esphome", "--dashboard", "clean", config_file]
return [*DASHBOARD_COMMAND, "clean", config_file]
class EsphomeVscodeHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
return ["esphome", "--dashboard", "-q", "vscode", "dummy"]
return [*DASHBOARD_COMMAND, "-q", "vscode", "dummy"]
class EsphomeAceEditorHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
return ["esphome", "--dashboard", "-q", "vscode", "--ace", settings.config_dir]
return [*DASHBOARD_COMMAND, "-q", "vscode", "--ace", settings.config_dir]
class EsphomeUpdateAllHandler(EsphomeCommandWebSocket):
def build_command(self, json_message):
return ["esphome", "--dashboard", "update-all", settings.config_dir]
return [*DASHBOARD_COMMAND, "update-all", settings.config_dir]
class SerialPortRequestHandler(BaseHandler):
@ -964,23 +970,38 @@ class BoardsRequestHandler(BaseHandler):
class MDNSStatusThread(threading.Thread):
def __init__(self):
"""Thread that updates the mdns status."""
def __init__(self) -> None:
"""Initialize the MDNSStatusThread."""
super().__init__()
self.zeroconf: EsphomeZeroconf | None = None
# This is the current mdns state for each host (True, False, None)
self.host_mdns_state: dict[str, bool | None] = {}
# This is the hostnames to filenames mapping
self.host_name_to_filename: dict[str, str] = {}
self.filename_to_host_name: dict[str, str] = {}
# This is a set of host names to track (i.e no_mdns = false)
self.host_name_with_mdns_enabled: set[set] = set()
self._refresh_hosts()
def filename_to_host_name_thread_safe(self, filename: str) -> str | None:
"""Resolve a filename to an address in a thread-safe manner."""
return self.filename_to_host_name.get(filename)
def resolve_host_thread_safe(self, host_name: str) -> str | None:
"""Resolve a host name to an address in a thread-safe manner."""
if zc := self.zeroconf:
return zc.resolve_host(host_name)
return None
def _refresh_hosts(self):
"""Refresh the hosts to track."""
entries = _list_dashboard_entries()
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
host_mdns_state = self.host_mdns_state
host_name_to_filename = self.host_name_to_filename
filename_to_host_name = self.filename_to_host_name
for entry in entries:
name = entry.name
@ -1003,11 +1024,13 @@ class MDNSStatusThread(threading.Thread):
# so when we get an mdns update we can map it back
# to the filename
host_name_to_filename[name] = filename
filename_to_host_name[filename] = name
def run(self):
global IMPORT_RESULT
zc = EsphomeZeroconf()
self.zeroconf = zc
host_mdns_state = self.host_mdns_state
host_name_to_filename = self.host_name_to_filename
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
@ -1035,6 +1058,7 @@ class MDNSStatusThread(threading.Thread):
browser.cancel()
zc.close()
self.zeroconf = None
class PingStatusThread(threading.Thread):
@ -1211,11 +1235,26 @@ class UndoDeleteRequestHandler(BaseHandler):
shutil.move(os.path.join(trash_path, configuration), config_file)
class MDNSContainer:
def __init__(self) -> None:
"""Initialize the MDNSContainer."""
self._mdns: MDNSStatusThread | None = None
def set_mdns(self, mdns: MDNSStatusThread) -> None:
"""Set the MDNSStatusThread instance."""
self._mdns = mdns
def get_mdns(self) -> MDNSStatusThread | None:
"""Return the MDNSStatusThread instance."""
return self._mdns
PING_RESULT: dict = {}
IMPORT_RESULT = {}
STOP_EVENT = threading.Event()
PING_REQUEST = threading.Event()
MQTT_PING_REQUEST = threading.Event()
MDNS_CONTAINER = MDNSContainer()
class LoginHandler(BaseHandler):
@ -1506,6 +1545,7 @@ def start_web_server(args):
status_thread = PingStatusThread()
else:
status_thread = MDNSStatusThread()
MDNS_CONTAINER.set_mdns(status_thread)
status_thread.start()
if settings.status_use_mqtt:
@ -1519,6 +1559,7 @@ def start_web_server(args):
STOP_EVENT.set()
PING_REQUEST.set()
status_thread.join()
MDNS_CONTAINER.set_mdns(None)
if settings.status_use_mqtt:
status_thread_mqtt.join()
MQTT_PING_REQUEST.set()

View file

@ -147,12 +147,13 @@ class DashboardImportDiscovery:
class EsphomeZeroconf(Zeroconf):
def resolve_host(self, host: str, timeout=3.0):
def resolve_host(self, host: str, timeout: float = 3.0) -> str | None:
"""Resolve a host name to an IP address."""
name = host.partition(".")[0]
info = HostResolver(f"{name}.{ESPHOME_SERVICE_TYPE}", ESPHOME_SERVICE_TYPE)
if (info.load_from_cache(self) or info.request(self, timeout * 1000)) and (
addresses := info.ip_addresses_by_version(IPVersion.V4Only)
):
info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}")
if (
info.load_from_cache(self)
or (timeout and info.request(self, timeout * 1000))
) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)):
return str(addresses[0])
return None