Merge branch 'dev' into hbridge-switch

This commit is contained in:
David Woodhouse 2024-11-08 03:41:30 +00:00 committed by GitHub
commit ec9e6d05b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 117 additions and 57 deletions

View file

@ -38,7 +38,7 @@ from esphome.const import (
SECRETS_FILES, SECRETS_FILES,
) )
from esphome.core import CORE, EsphomeError, coroutine from esphome.core import CORE, EsphomeError, coroutine
from esphome.helpers import indent, is_ip_address, get_bool_env from esphome.helpers import get_bool_env, indent, is_ip_address
from esphome.log import Fore, color, setup_log from esphome.log import Fore, color, setup_log
from esphome.util import ( from esphome.util import (
get_serial_ports, get_serial_ports,
@ -378,7 +378,7 @@ def show_logs(config, args, port):
port = mqtt.get_esphome_device_ip( port = mqtt.get_esphome_device_ip(
config, args.username, args.password, args.client_id config, args.username, args.password, args.client_id
) )[0]
from esphome.components.api.client import run_logs from esphome.components.api.client import run_logs

View file

@ -10,7 +10,7 @@ import sys
import time import time
from esphome.core import EsphomeError from esphome.core import EsphomeError
from esphome.helpers import is_ip_address, resolve_ip_address from esphome.helpers import resolve_ip_address
RESPONSE_OK = 0x00 RESPONSE_OK = 0x00
RESPONSE_REQUEST_AUTH = 0x01 RESPONSE_REQUEST_AUTH = 0x01
@ -311,44 +311,45 @@ def perform_ota(
def run_ota_impl_(remote_host, remote_port, password, filename): def run_ota_impl_(remote_host, remote_port, password, filename):
if is_ip_address(remote_host):
_LOGGER.info("Connecting to %s", remote_host)
ip = remote_host
else:
_LOGGER.info("Resolving IP address of %s", remote_host)
try:
ip = resolve_ip_address(remote_host)
except EsphomeError as err:
_LOGGER.error(
"Error resolving IP address of %s. Is it connected to WiFi?",
remote_host,
)
_LOGGER.error(
"(If this error persists, please set a static IP address: "
"https://esphome.io/components/wifi.html#manual-ips)"
)
raise OTAError(err) from err
_LOGGER.info(" -> %s", ip)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(10.0)
try: try:
sock.connect((ip, remote_port)) res = resolve_ip_address(remote_host, remote_port)
except OSError as err: except EsphomeError as err:
sock.close() _LOGGER.error(
_LOGGER.error("Connecting to %s:%s failed: %s", remote_host, remote_port, err) "Error resolving IP address of %s. Is it connected to WiFi?",
return 1 remote_host,
)
_LOGGER.error(
"(If this error persists, please set a static IP address: "
"https://esphome.io/components/wifi.html#manual-ips)"
)
raise OTAError(err) from err
with open(filename, "rb") as file_handle: for r in res:
af, socktype, _, _, sa = r
_LOGGER.info("Connecting to %s port %s...", sa[0], sa[1])
sock = socket.socket(af, socktype)
sock.settimeout(10.0)
try: try:
perform_ota(sock, password, file_handle, filename) sock.connect(sa)
except OTAError as err: except OSError as err:
_LOGGER.error(str(err))
return 1
finally:
sock.close() sock.close()
_LOGGER.error("Connecting to %s port %s failed: %s", sa[0], sa[1], err)
continue
return 0 _LOGGER.info("Connected to %s", sa[0])
with open(filename, "rb") as file_handle:
try:
perform_ota(sock, password, file_handle, filename)
except OTAError as err:
_LOGGER.error(str(err))
return 1
finally:
sock.close()
return 0
_LOGGER.error("Connection failed.")
return 1
def run_ota(remote_host, remote_port, password, filename): def run_ota(remote_host, remote_port, password, filename):

View file

@ -1,5 +1,6 @@
import codecs import codecs
from contextlib import suppress from contextlib import suppress
import ipaddress
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
@ -91,12 +92,8 @@ def mkdir_p(path):
def is_ip_address(host): def is_ip_address(host):
parts = host.split(".")
if len(parts) != 4:
return False
try: try:
for p in parts: ipaddress.ip_address(host)
int(p)
return True return True
except ValueError: except ValueError:
return False return False
@ -127,25 +124,80 @@ def _resolve_with_zeroconf(host):
return info return info
def resolve_ip_address(host): def addr_preference_(res):
# Trivial alternative to RFC6724 sorting. Put sane IPv6 first, then
# Legacy IP, then IPv6 link-local addresses without an actual link.
sa = res[4]
ip = ipaddress.ip_address(sa[0])
if ip.version == 4:
return 2
if ip.is_link_local and sa[3] == 0:
return 3
return 1
def resolve_ip_address(host, port):
import socket import socket
from esphome.core import EsphomeError from esphome.core import EsphomeError
# There are five cases here. The host argument could be one of:
# • a *list* of IP addresses discovered by MQTT,
# • a single IP address specified by the user,
# • a .local hostname to be resolved by mDNS,
# • a normal hostname to be resolved in DNS, or
# • A URL from which we should extract the hostname.
#
# In each of the first three cases, we end up with IP addresses in
# string form which need to be converted to a 5-tuple to be used
# for the socket connection attempt. The easiest way to construct
# those is to pass the IP address string to getaddrinfo(). Which,
# coincidentally, is how we do hostname lookups in the other cases
# too. So first build a list which contains either IP addresses or
# a single hostname, then call getaddrinfo() on each element of
# that list.
errs = [] errs = []
if isinstance(host, list):
addr_list = host
elif is_ip_address(host):
addr_list = [host]
else:
url = urlparse(host)
if url.scheme != "":
host = url.hostname
if host.endswith(".local"): addr_list = []
if host.endswith(".local"):
try:
_LOGGER.info("Resolving IP address of %s in mDNS", host)
addr_list = _resolve_with_zeroconf(host)
except EsphomeError as err:
errs.append(str(err))
# If not mDNS, or if mDNS failed, use normal DNS
if not addr_list:
addr_list = [host]
# Now we have a list containing either IP addresses or a hostname
res = []
for addr in addr_list:
if not is_ip_address(addr):
_LOGGER.info("Resolving IP address of %s", host)
try: try:
return _resolve_with_zeroconf(host) r = socket.getaddrinfo(addr, port, proto=socket.IPPROTO_TCP)
except EsphomeError as err: except OSError as err:
errs.append(str(err)) errs.append(str(err))
raise EsphomeError(
f"Error resolving IP address: {', '.join(errs)}"
) from err
try: res = res + r
host_url = host if (urlparse(host).scheme != "") else "http://" + host
return socket.gethostbyname(urlparse(host_url).hostname) # Zeroconf tends to give us link-local IPv6 addresses without specifying
except OSError as err: # the link. Put those last in the list to be attempted.
errs.append(str(err)) res.sort(key=addr_preference_)
raise EsphomeError(f"Error resolving IP address: {', '.join(errs)}") from err return res
def get_bool_env(var, default=False): def get_bool_env(var, default=False):

View file

@ -175,8 +175,15 @@ def get_esphome_device_ip(
_LOGGER.Warn("Wrong device answer") _LOGGER.Warn("Wrong device answer")
return return
if "ip" in data: dev_ip = []
dev_ip = data["ip"] key = "ip"
n = 0
while key in data:
dev_ip.append(data[key])
n = n + 1
key = "ip" + str(n)
if dev_ip:
client.disconnect() client.disconnect()
def on_connect(client, userdata, flags, return_code): def on_connect(client, userdata, flags, return_code):

View file

@ -182,8 +182,8 @@ class EsphomeZeroconf(Zeroconf):
if ( if (
info.load_from_cache(self) info.load_from_cache(self)
or (timeout and info.request(self, timeout * 1000)) or (timeout and info.request(self, timeout * 1000))
) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): ) and (addresses := info.parsed_scoped_addresses(IPVersion.All)):
return str(addresses[0]) return addresses
return None return None
@ -194,6 +194,6 @@ class AsyncEsphomeZeroconf(AsyncZeroconf):
if ( if (
info.load_from_cache(self.zeroconf) info.load_from_cache(self.zeroconf)
or (timeout and await info.async_request(self.zeroconf, timeout * 1000)) or (timeout and await info.async_request(self.zeroconf, timeout * 1000))
) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): ) and (addresses := info.parsed_scoped_addresses(IPVersion.All)):
return str(addresses[0]) return addresses
return None return None