Use mdns or freshen cache when device connection is OTA

Since we already have a service browser running, we likely
already know the IP of the deivce we want to connect to so
we can replace OTA with the address to avoid the esphome
app having to look it up again
This commit is contained in:
J. Nick Koston 2023-11-09 23:13:27 -06:00
parent 0f19450ab4
commit 4e21532578
No known key found for this signature in database
2 changed files with 87 additions and 45 deletions

View file

@ -32,7 +32,7 @@ import tornado.web
import tornado.websocket import tornado.websocket
import yaml import yaml
from tornado.log import access_log from tornado.log import access_log
from typing import Any
from esphome import const, platformio_api, util, yaml_util from esphome import const, platformio_api, util, yaml_util
from esphome.core import CORE from esphome.core import CORE
from esphome.helpers import get_bool_env, mkdir_p, run_system_command from esphome.helpers import get_bool_env, mkdir_p, run_system_command
@ -398,19 +398,40 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
raise NotImplementedError raise NotImplementedError
class EsphomeLogsHandler(EsphomeCommandWebSocket): DASHBOARD_COMMAND = ["esphome", "--dashboard"]
def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"])
class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
"""Base class for commands that require a port."""
def run_command(self, args: list[str], json_message: dict[str, Any]) -> list[str]:
"""Build the command to run."""
configuration = json_message["configuration"]
config_file = settings.rel_path(configuration)
port = json_message["port"]
if (
port == "OTA"
and (mdns := MDNS_CONTAINER.get_mdns())
and (host_name := mdns.filename_to_host_name_thread_safe(configuration))
and (address := mdns.resolve_host_thread_safe(host_name))
):
port = address
return [ return [
"esphome", *DASHBOARD_COMMAND,
"--dashboard", *args,
"logs",
config_file, config_file,
"--device", "--device",
json_message["port"], port,
] ]
class EsphomeLogsHandler(EsphomePortCommandWebSocket):
def build_command(self, json_message: dict[str, Any]) -> list[str]:
"""Build the command to run."""
return self.run_command(["logs"], json_message)
class EsphomeRenameHandler(EsphomeCommandWebSocket): class EsphomeRenameHandler(EsphomeCommandWebSocket):
old_name: str old_name: str
@ -418,8 +439,7 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket):
config_file = settings.rel_path(json_message["configuration"]) config_file = settings.rel_path(json_message["configuration"])
self.old_name = json_message["configuration"] self.old_name = json_message["configuration"]
return [ return [
"esphome", *DASHBOARD_COMMAND,
"--dashboard",
"rename", "rename",
config_file, config_file,
json_message["newName"], json_message["newName"],
@ -435,36 +455,22 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket):
PING_RESULT.pop(self.old_name, None) PING_RESULT.pop(self.old_name, None)
class EsphomeUploadHandler(EsphomeCommandWebSocket): class EsphomeUploadHandler(EsphomePortCommandWebSocket):
def build_command(self, json_message): def build_command(self, json_message: dict[str, Any]) -> list[str]:
config_file = settings.rel_path(json_message["configuration"]) """Build the command to run."""
return [ return self.run_command(["upload"], json_message)
"esphome",
"--dashboard",
"upload",
config_file,
"--device",
json_message["port"],
]
class EsphomeRunHandler(EsphomeCommandWebSocket): class EsphomeRunHandler(EsphomePortCommandWebSocket):
def build_command(self, json_message): def build_command(self, json_message: dict[str, Any]) -> list[str]:
config_file = settings.rel_path(json_message["configuration"]) """Build the command to run."""
return [ return self.run_command(["run"], json_message)
"esphome",
"--dashboard",
"run",
config_file,
"--device",
json_message["port"],
]
class EsphomeCompileHandler(EsphomeCommandWebSocket): class EsphomeCompileHandler(EsphomeCommandWebSocket):
def build_command(self, json_message): def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"]) config_file = settings.rel_path(json_message["configuration"])
command = ["esphome", "--dashboard", "compile"] command = [*DASHBOARD_COMMAND, "compile"]
if json_message.get("only_generate", False): if json_message.get("only_generate", False):
command.append("--only-generate") command.append("--only-generate")
command.append(config_file) command.append(config_file)
@ -474,7 +480,7 @@ class EsphomeCompileHandler(EsphomeCommandWebSocket):
class EsphomeValidateHandler(EsphomeCommandWebSocket): class EsphomeValidateHandler(EsphomeCommandWebSocket):
def build_command(self, json_message): def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"]) config_file = settings.rel_path(json_message["configuration"])
command = ["esphome", "--dashboard", "config", config_file] command = [*DASHBOARD_COMMAND, "config", config_file]
if not settings.streamer_mode: if not settings.streamer_mode:
command.append("--show-secrets") command.append("--show-secrets")
return command return command
@ -483,28 +489,28 @@ class EsphomeValidateHandler(EsphomeCommandWebSocket):
class EsphomeCleanMqttHandler(EsphomeCommandWebSocket): class EsphomeCleanMqttHandler(EsphomeCommandWebSocket):
def build_command(self, json_message): def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"]) config_file = settings.rel_path(json_message["configuration"])
return ["esphome", "--dashboard", "clean-mqtt", config_file] return [*DASHBOARD_COMMAND, "clean-mqtt", config_file]
class EsphomeCleanHandler(EsphomeCommandWebSocket): class EsphomeCleanHandler(EsphomeCommandWebSocket):
def build_command(self, json_message): def build_command(self, json_message):
config_file = settings.rel_path(json_message["configuration"]) config_file = settings.rel_path(json_message["configuration"])
return ["esphome", "--dashboard", "clean", config_file] return [*DASHBOARD_COMMAND, "clean", config_file]
class EsphomeVscodeHandler(EsphomeCommandWebSocket): class EsphomeVscodeHandler(EsphomeCommandWebSocket):
def build_command(self, json_message): def build_command(self, json_message):
return ["esphome", "--dashboard", "-q", "vscode", "dummy"] return [*DASHBOARD_COMMAND, "-q", "vscode", "dummy"]
class EsphomeAceEditorHandler(EsphomeCommandWebSocket): class EsphomeAceEditorHandler(EsphomeCommandWebSocket):
def build_command(self, json_message): def build_command(self, json_message):
return ["esphome", "--dashboard", "-q", "vscode", "--ace", settings.config_dir] return [*DASHBOARD_COMMAND, "-q", "vscode", "--ace", settings.config_dir]
class EsphomeUpdateAllHandler(EsphomeCommandWebSocket): class EsphomeUpdateAllHandler(EsphomeCommandWebSocket):
def build_command(self, json_message): def build_command(self, json_message):
return ["esphome", "--dashboard", "update-all", settings.config_dir] return [*DASHBOARD_COMMAND, "update-all", settings.config_dir]
class SerialPortRequestHandler(BaseHandler): class SerialPortRequestHandler(BaseHandler):
@ -964,23 +970,38 @@ class BoardsRequestHandler(BaseHandler):
class MDNSStatusThread(threading.Thread): class MDNSStatusThread(threading.Thread):
def __init__(self): """Thread that updates the mdns status."""
def __init__(self) -> None:
"""Initialize the MDNSStatusThread.""" """Initialize the MDNSStatusThread."""
super().__init__() super().__init__()
self.zeroconf: EsphomeZeroconf | None = None
# This is the current mdns state for each host (True, False, None) # This is the current mdns state for each host (True, False, None)
self.host_mdns_state: dict[str, bool | None] = {} self.host_mdns_state: dict[str, bool | None] = {}
# This is the hostnames to filenames mapping # This is the hostnames to filenames mapping
self.host_name_to_filename: dict[str, str] = {} self.host_name_to_filename: dict[str, str] = {}
self.filename_to_host_name: dict[str, str] = {}
# This is a set of host names to track (i.e no_mdns = false) # This is a set of host names to track (i.e no_mdns = false)
self.host_name_with_mdns_enabled: set[set] = set() self.host_name_with_mdns_enabled: set[set] = set()
self._refresh_hosts() self._refresh_hosts()
def filename_to_host_name_thread_safe(self, filename: str) -> str | None:
"""Resolve a filename to an address in a thread-safe manner."""
return self.filename_to_host_name.get(filename)
def resolve_host_thread_safe(self, host_name: str) -> str | None:
"""Resolve a host name to an address in a thread-safe manner."""
if zc := self.zeroconf:
return zc.resolve_host(host_name)
return None
def _refresh_hosts(self): def _refresh_hosts(self):
"""Refresh the hosts to track.""" """Refresh the hosts to track."""
entries = _list_dashboard_entries() entries = _list_dashboard_entries()
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
host_mdns_state = self.host_mdns_state host_mdns_state = self.host_mdns_state
host_name_to_filename = self.host_name_to_filename host_name_to_filename = self.host_name_to_filename
filename_to_host_name = self.filename_to_host_name
for entry in entries: for entry in entries:
name = entry.name name = entry.name
@ -1003,11 +1024,13 @@ class MDNSStatusThread(threading.Thread):
# so when we get an mdns update we can map it back # so when we get an mdns update we can map it back
# to the filename # to the filename
host_name_to_filename[name] = filename host_name_to_filename[name] = filename
filename_to_host_name[filename] = name
def run(self): def run(self):
global IMPORT_RESULT global IMPORT_RESULT
zc = EsphomeZeroconf() zc = EsphomeZeroconf()
self.zeroconf = zc
host_mdns_state = self.host_mdns_state host_mdns_state = self.host_mdns_state
host_name_to_filename = self.host_name_to_filename host_name_to_filename = self.host_name_to_filename
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
@ -1035,6 +1058,7 @@ class MDNSStatusThread(threading.Thread):
browser.cancel() browser.cancel()
zc.close() zc.close()
self.zeroconf = None
class PingStatusThread(threading.Thread): class PingStatusThread(threading.Thread):
@ -1211,11 +1235,26 @@ class UndoDeleteRequestHandler(BaseHandler):
shutil.move(os.path.join(trash_path, configuration), config_file) shutil.move(os.path.join(trash_path, configuration), config_file)
class MDNSContainer:
def __init__(self) -> None:
"""Initialize the MDNSContainer."""
self._mdns: MDNSStatusThread | None = None
def set_mdns(self, mdns: MDNSStatusThread) -> None:
"""Set the MDNSStatusThread instance."""
self._mdns = mdns
def get_mdns(self) -> MDNSStatusThread | None:
"""Return the MDNSStatusThread instance."""
return self._mdns
PING_RESULT: dict = {} PING_RESULT: dict = {}
IMPORT_RESULT = {} IMPORT_RESULT = {}
STOP_EVENT = threading.Event() STOP_EVENT = threading.Event()
PING_REQUEST = threading.Event() PING_REQUEST = threading.Event()
MQTT_PING_REQUEST = threading.Event() MQTT_PING_REQUEST = threading.Event()
MDNS_CONTAINER = MDNSContainer()
class LoginHandler(BaseHandler): class LoginHandler(BaseHandler):
@ -1506,6 +1545,7 @@ def start_web_server(args):
status_thread = PingStatusThread() status_thread = PingStatusThread()
else: else:
status_thread = MDNSStatusThread() status_thread = MDNSStatusThread()
MDNS_CONTAINER.set_mdns(status_thread)
status_thread.start() status_thread.start()
if settings.status_use_mqtt: if settings.status_use_mqtt:
@ -1519,6 +1559,7 @@ def start_web_server(args):
STOP_EVENT.set() STOP_EVENT.set()
PING_REQUEST.set() PING_REQUEST.set()
status_thread.join() status_thread.join()
MDNS_CONTAINER.set_mdns(None)
if settings.status_use_mqtt: if settings.status_use_mqtt:
status_thread_mqtt.join() status_thread_mqtt.join()
MQTT_PING_REQUEST.set() MQTT_PING_REQUEST.set()

View file

@ -147,12 +147,13 @@ class DashboardImportDiscovery:
class EsphomeZeroconf(Zeroconf): class EsphomeZeroconf(Zeroconf):
def resolve_host(self, host: str, timeout=3.0): def resolve_host(self, host: str, timeout: float = 3.0) -> str | None:
"""Resolve a host name to an IP address.""" """Resolve a host name to an IP address."""
name = host.partition(".")[0] name = host.partition(".")[0]
info = HostResolver(f"{name}.{ESPHOME_SERVICE_TYPE}", ESPHOME_SERVICE_TYPE) info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}")
if (info.load_from_cache(self) or info.request(self, timeout * 1000)) and ( if (
addresses := info.ip_addresses_by_version(IPVersion.V4Only) info.load_from_cache(self)
): or (timeout and info.request(self, timeout * 1000))
) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)):
return str(addresses[0]) return str(addresses[0])
return None return None