allow to use MQTT for discovery of IPs if mDNS is no option (#3887)

Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
This commit is contained in:
Markus 2023-05-17 06:29:56 +02:00 committed by GitHub
parent 0de47e2a4e
commit c5a45645a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 353 additions and 21 deletions

View file

@ -18,6 +18,9 @@ from esphome.const import (
CONF_LOGGER,
CONF_NAME,
CONF_OTA,
CONF_MQTT,
CONF_MDNS,
CONF_DISABLED,
CONF_PASSWORD,
CONF_PORT,
CONF_ESPHOME,
@ -42,7 +45,7 @@ from esphome.log import color, setup_log, Fore
_LOGGER = logging.getLogger(__name__)
def choose_prompt(options):
def choose_prompt(options, purpose: str = None):
if not options:
raise EsphomeError(
"Found no valid options for upload/logging, please make sure relevant "
@ -53,7 +56,9 @@ def choose_prompt(options):
if len(options) == 1:
return options[0][1]
safe_print("Found multiple options, please choose one:")
safe_print(
f'Found multiple options{f" for {purpose}" if purpose else ""}, please choose one:'
)
for i, (desc, _) in enumerate(options):
safe_print(f" [{i+1}] {desc}")
@ -72,7 +77,9 @@ def choose_prompt(options):
return options[opt - 1][1]
def choose_upload_log_host(default, check_default, show_ota, show_mqtt, show_api):
def choose_upload_log_host(
default, check_default, show_ota, show_mqtt, show_api, purpose: str = None
):
options = []
for port in get_serial_ports():
options.append((f"{port.path} ({port.description})", port.path))
@ -80,7 +87,7 @@ def choose_upload_log_host(default, check_default, show_ota, show_mqtt, show_api
options.append((f"Over The Air ({CORE.address})", CORE.address))
if default == "OTA":
return CORE.address
if show_mqtt and "mqtt" in CORE.config:
if show_mqtt and CONF_MQTT in CORE.config:
options.append((f"MQTT ({CORE.config['mqtt'][CONF_BROKER]})", "MQTT"))
if default == "OTA":
return "MQTT"
@ -88,7 +95,7 @@ def choose_upload_log_host(default, check_default, show_ota, show_mqtt, show_api
return default
if check_default is not None and check_default in [opt[1] for opt in options]:
return check_default
return choose_prompt(options)
return choose_prompt(options, purpose=purpose)
def get_port_type(port):
@ -288,19 +295,30 @@ def upload_program(config, args, host):
return 1 # Unknown target platform
from esphome import espota2
if CONF_OTA not in config:
raise EsphomeError(
"Cannot upload Over the Air as the config does not include the ota: "
"component"
)
from esphome import espota2
ota_conf = config[CONF_OTA]
remote_port = ota_conf[CONF_PORT]
password = ota_conf.get(CONF_PASSWORD, "")
if (
get_port_type(host) == "MQTT" or config[CONF_MDNS][CONF_DISABLED]
) and CONF_MQTT in config:
from esphome import mqtt
host = mqtt.get_esphome_device_ip(
config, args.username, args.password, args.client_id
)
if getattr(args, "file", None) is not None:
return espota2.run_ota(host, remote_port, password, args.file)
return espota2.run_ota(host, remote_port, password, CORE.firmware_bin)
@ -310,6 +328,13 @@ def show_logs(config, args, port):
if get_port_type(port) == "SERIAL":
return run_miniterm(config, port)
if get_port_type(port) == "NETWORK" and "api" in config:
if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config:
from esphome import mqtt
port = mqtt.get_esphome_device_ip(
config, args.username, args.password, args.client_id
)
from esphome.components.api.client import run_logs
return run_logs(config, port)
@ -374,6 +399,7 @@ def command_upload(args, config):
show_ota=True,
show_mqtt=False,
show_api=False,
purpose="uploading",
)
exit_code = upload_program(config, args, port)
if exit_code != 0:
@ -382,6 +408,15 @@ def command_upload(args, config):
return 0
def command_discover(args, config):
if "mqtt" in config:
from esphome import mqtt
return mqtt.show_discover(config, args.username, args.password, args.client_id)
raise EsphomeError("No discover method configured (mqtt)")
def command_logs(args, config):
port = choose_upload_log_host(
default=args.device,
@ -389,6 +424,7 @@ def command_logs(args, config):
show_ota=False,
show_mqtt=True,
show_api=True,
purpose="logging",
)
return show_logs(config, args, port)
@ -407,6 +443,7 @@ def command_run(args, config):
show_ota=True,
show_mqtt=False,
show_api=True,
purpose="uploading",
)
exit_code = upload_program(config, args, port)
if exit_code != 0:
@ -420,6 +457,7 @@ def command_run(args, config):
show_ota=False,
show_mqtt=True,
show_api=True,
purpose="logging",
)
return show_logs(config, args, port)
@ -623,6 +661,7 @@ POST_CONFIG_ACTIONS = {
"clean": command_clean,
"idedata": command_idedata,
"rename": command_rename,
"discover": command_discover,
}
@ -711,6 +750,15 @@ def parse_args(argv):
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.",
)
parser_discover = subparsers.add_parser(
"discover",
help="Validate the configuration and show all discovered devices.",
parents=[mqtt_options],
)
parser_discover.add_argument(
"configuration", help="Your YAML configuration file.", nargs=1
)
parser_run = subparsers.add_parser(
"run",
help="Validate the configuration, create a binary, upload it, and start logs.",

View file

@ -6,6 +6,7 @@ from esphome.const import (
CONF_SERVICE,
KEY_CORE,
KEY_FRAMEWORK_VERSION,
CONF_DISABLED,
)
import esphome.codegen as cg
import esphome.config_validation as cv
@ -39,7 +40,6 @@ SERVICE_SCHEMA = cv.Schema(
}
)
CONF_DISABLED = "disabled"
CONFIG_SCHEMA = cv.All(
cv.Schema(
{

View file

@ -7,6 +7,7 @@
#include "esphome/core/application.h"
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"
#include "esphome/core/version.h"
#ifdef USE_LOGGER
#include "esphome/components/logger/logger.h"
#endif
@ -14,6 +15,13 @@
#include "lwip/err.h"
#include "mqtt_component.h"
#ifdef USE_API
#include "esphome/components/api/api_server.h"
#endif
#ifdef USE_DASHBOARD_IMPORT
#include "esphome/components/dashboard_import/dashboard_import.h"
#endif
namespace esphome {
namespace mqtt {
@ -58,9 +66,63 @@ void MQTTClientComponent::setup() {
}
#endif
this->subscribe(
"esphome/discover", [this](const std::string &topic, const std::string &payload) { this->send_device_info_(); },
2);
std::string topic = "esphome/ping/";
topic.append(App.get_name());
this->subscribe(
topic, [this](const std::string &topic, const std::string &payload) { this->send_device_info_(); }, 2);
this->last_connected_ = millis();
this->start_dnslookup_();
}
void MQTTClientComponent::send_device_info_() {
if (!this->is_connected()) {
return;
}
std::string topic = "esphome/discover/";
topic.append(App.get_name());
this->publish_json(
topic,
[](JsonObject root) {
auto ip = network::get_ip_address();
root["ip"] = ip.str();
root["name"] = App.get_name();
#ifdef USE_API
root["port"] = api::global_api_server->get_port();
#endif
root["version"] = ESPHOME_VERSION;
root["mac"] = get_mac_address();
#ifdef USE_ESP8266
root["platform"] = "ESP8266";
#endif
#ifdef USE_ESP32
root["platform"] = "ESP32";
#endif
root["board"] = ESPHOME_BOARD;
#if defined(USE_WIFI)
root["network"] = "wifi";
#elif defined(USE_ETHERNET)
root["network"] = "ethernet";
#endif
#ifdef ESPHOME_PROJECT_NAME
root["project_name"] = ESPHOME_PROJECT_NAME;
root["project_version"] = ESPHOME_PROJECT_VERSION;
#endif // ESPHOME_PROJECT_NAME
#ifdef USE_DASHBOARD_IMPORT
root["package_import_url"] = dashboard_import::get_package_import_url();
#endif
},
2, this->discovery_info_.retain);
}
void MQTTClientComponent::dump_config() {
ESP_LOGCONFIG(TAG, "MQTT:");
ESP_LOGCONFIG(TAG, " Server Address: %s:%u (%s)", this->credentials_.address.c_str(), this->credentials_.port,
@ -226,6 +288,7 @@ void MQTTClientComponent::check_connected() {
delay(100); // NOLINT
this->resubscribe_subscriptions_();
this->send_device_info_();
for (MQTTComponent *component : this->children_)
component->schedule_resend_state();

View file

@ -251,6 +251,8 @@ class MQTTClientComponent : public Component {
void set_on_disconnect(mqtt_on_disconnect_callback_t &&callback);
protected:
void send_device_info_();
/// Reconnect to the MQTT broker if not already connected.
void start_connect_();
void start_dnslookup_();

View file

@ -180,6 +180,7 @@ CONF_DIR_PIN = "dir_pin"
CONF_DIRECTION = "direction"
CONF_DIRECTION_OUTPUT = "direction_output"
CONF_DISABLE_CRC = "disable_crc"
CONF_DISABLED = "disabled"
CONF_DISABLED_BY_DEFAULT = "disabled_by_default"
CONF_DISCONNECT_DELAY = "disconnect_delay"
CONF_DISCOVERY = "discovery"
@ -392,6 +393,7 @@ CONF_MAX_SPEED = "max_speed"
CONF_MAX_TEMPERATURE = "max_temperature"
CONF_MAX_VALUE = "max_value"
CONF_MAX_VOLTAGE = "max_voltage"
CONF_MDNS = "mdns"
CONF_MEASUREMENT_DURATION = "measurement_duration"
CONF_MEASUREMENT_SEQUENCE_NUMBER = "measurement_sequence_number"
CONF_MEDIUM = "medium"

View file

@ -1,4 +1,5 @@
import base64
import binascii
import codecs
import collections
import functools
@ -76,6 +77,10 @@ class DashboardSettings:
def status_use_ping(self):
return get_bool_env("ESPHOME_DASHBOARD_USE_PING")
@property
def status_use_mqtt(self):
return get_bool_env("ESPHOME_DASHBOARD_USE_MQTT")
@property
def using_ha_addon_auth(self):
if not self.on_ha_addon:
@ -583,6 +588,12 @@ class DashboardEntry:
return None
return self.storage.address
@property
def no_mdns(self):
if self.storage is None:
return None
return self.storage.no_mdns
@property
def web_port(self):
if self.storage is None:
@ -775,9 +786,12 @@ class MDNSStatusThread(threading.Thread):
stat.start()
while not STOP_EVENT.is_set():
entries = _list_dashboard_entries()
stat.request_query(
{entry.filename: f"{entry.name}.local." for entry in entries}
)
hosts = {}
for entry in entries:
if entry.no_mdns is not True:
hosts[entry.filename] = f"{entry.name}.local."
stat.request_query(hosts)
IMPORT_RESULT = imports.import_state
PING_REQUEST.wait()
@ -801,6 +815,9 @@ class PingStatusThread(threading.Thread):
entries = _list_dashboard_entries()
queue = collections.deque()
for entry in entries:
if entry.no_mdns is True:
continue
if entry.address is None:
PING_RESULT[entry.filename] = None
continue
@ -832,10 +849,67 @@ class PingStatusThread(threading.Thread):
PING_REQUEST.clear()
class MqttStatusThread(threading.Thread):
def run(self):
from esphome import mqtt
entries = _list_dashboard_entries()
config = mqtt.config_from_env()
topic = "esphome/discover/#"
def on_message(client, userdata, msg):
nonlocal entries
payload = msg.payload.decode(errors="backslashreplace")
if len(payload) > 0:
data = json.loads(payload)
if "name" not in data:
return
for entry in entries:
if entry.name == data["name"]:
PING_RESULT[entry.filename] = True
return
def on_connect(client, userdata, flags, return_code):
client.publish("esphome/discover", None, retain=False)
mqttid = str(binascii.hexlify(os.urandom(6)).decode())
client = mqtt.prepare(
config,
[topic],
on_message,
on_connect,
None,
None,
f"esphome-dashboard-{mqttid}",
)
client.loop_start()
while not STOP_EVENT.wait(2):
# update entries
entries = _list_dashboard_entries()
# will be set to true on on_message
for entry in entries:
if entry.no_mdns:
PING_RESULT[entry.filename] = False
client.publish("esphome/discover", None, retain=False)
MQTT_PING_REQUEST.wait()
MQTT_PING_REQUEST.clear()
client.disconnect()
client.loop_stop()
class PingRequestHandler(BaseHandler):
@authenticated
def get(self):
PING_REQUEST.set()
if settings.status_use_mqtt:
MQTT_PING_REQUEST.set()
self.set_header("content-type", "application/json")
self.write(json.dumps(PING_RESULT))
@ -910,6 +984,7 @@ PING_RESULT: dict = {}
IMPORT_RESULT = {}
STOP_EVENT = threading.Event()
PING_REQUEST = threading.Event()
MQTT_PING_REQUEST = threading.Event()
class LoginHandler(BaseHandler):
@ -1197,6 +1272,11 @@ def start_web_server(args):
else:
status_thread = MDNSStatusThread()
status_thread.start()
if settings.status_use_mqtt:
status_thread_mqtt = MqttStatusThread()
status_thread_mqtt.start()
try:
tornado.ioloop.IOLoop.current().start()
except KeyboardInterrupt:
@ -1204,5 +1284,8 @@ def start_web_server(args):
STOP_EVENT.set()
PING_REQUEST.set()
status_thread.join()
if settings.status_use_mqtt:
status_thread_mqtt.join()
MQTT_PING_REQUEST.set()
if args.socket is not None:
os.remove(args.socket)

View file

@ -147,6 +147,14 @@ def get_bool_env(var, default=False):
return bool(os.getenv(var, default))
def get_str_env(var, default=None):
return str(os.getenv(var, default))
def get_int_env(var, default=0):
return int(os.getenv(var, default))
def is_ha_addon():
return get_bool_env("ESPHOME_IS_HA_ADDON")

View file

@ -4,6 +4,7 @@ import logging
import ssl
import sys
import time
import json
import paho.mqtt.client as mqtt
@ -24,15 +25,45 @@ from esphome.const import (
from esphome.core import CORE, EsphomeError
from esphome.log import color, Fore
from esphome.util import safe_print
from esphome.helpers import get_str_env, get_int_env
_LOGGER = logging.getLogger(__name__)
def initialize(config, subscriptions, on_message, username, password, client_id):
def on_connect(client, userdata, flags, return_code):
def config_from_env():
config = {
CONF_MQTT: {
CONF_USERNAME: get_str_env("ESPHOME_DASHBOARD_MQTT_USERNAME"),
CONF_PASSWORD: get_str_env("ESPHOME_DASHBOARD_MQTT_PASSWORD"),
CONF_BROKER: get_str_env("ESPHOME_DASHBOARD_MQTT_BROKER"),
CONF_PORT: get_int_env("ESPHOME_DASHBOARD_MQTT_PORT", 1883),
},
}
return config
def initialize(
config, subscriptions, on_message, on_connect, username, password, client_id
):
client = prepare(
config, subscriptions, on_message, on_connect, username, password, client_id
)
try:
client.loop_forever()
except KeyboardInterrupt:
pass
return 0
def prepare(
config, subscriptions, on_message, on_connect, username, password, client_id
):
def on_connect_(client, userdata, flags, return_code):
_LOGGER.info("Connected to MQTT broker!")
for topic in subscriptions:
client.subscribe(topic)
if on_connect is not None:
on_connect(client, userdata, flags, return_code)
def on_disconnect(client, userdata, result_code):
if result_code == 0:
@ -57,7 +88,7 @@ def initialize(config, subscriptions, on_message, username, password, client_id)
tries += 1
client = mqtt.Client(client_id or "")
client.on_connect = on_connect
client.on_connect = on_connect_
client.on_message = on_message
client.on_disconnect = on_disconnect
if username is None:
@ -89,11 +120,88 @@ def initialize(config, subscriptions, on_message, username, password, client_id)
except OSError as err:
raise EsphomeError(f"Cannot connect to MQTT broker: {err}") from err
try:
client.loop_forever()
except KeyboardInterrupt:
pass
return 0
return client
def show_discover(config, username=None, password=None, client_id=None):
topic = "esphome/discover/#"
_LOGGER.info("Starting log output from %s", topic)
def on_message(client, userdata, msg):
time_ = datetime.now().time().strftime("[%H:%M:%S]")
payload = msg.payload.decode(errors="backslashreplace")
if len(payload) > 0:
message = time_ + " " + payload
safe_print(message)
def on_connect(client, userdata, flags, return_code):
_LOGGER.info("Send discover via MQTT broker")
client.publish("esphome/discover", None, retain=False)
return initialize(
config, [topic], on_message, on_connect, username, password, client_id
)
def get_esphome_device_ip(
config, username=None, password=None, client_id=None, timeout=25
):
if CONF_MQTT not in config:
raise EsphomeError(
"Cannot discover IP via MQTT as the config does not include the mqtt: "
"component"
)
if CONF_ESPHOME not in config or CONF_NAME not in config[CONF_ESPHOME]:
raise EsphomeError(
"Cannot discover IP via MQTT as the config does not include the device name: "
"component"
)
dev_name = config[CONF_ESPHOME][CONF_NAME]
dev_ip = None
topic = "esphome/discover/" + dev_name
_LOGGER.info("Starting looking for IP in topic %s", topic)
def on_message(client, userdata, msg):
nonlocal dev_ip
time_ = datetime.now().time().strftime("[%H:%M:%S]")
payload = msg.payload.decode(errors="backslashreplace")
if len(payload) > 0:
message = time_ + " " + payload
_LOGGER.debug(message)
data = json.loads(payload)
if "name" not in data or data["name"] != dev_name:
_LOGGER.Warn("Wrong device answer")
return
if "ip" in data:
dev_ip = data["ip"]
client.disconnect()
def on_connect(client, userdata, flags, return_code):
topic = "esphome/ping/" + dev_name
_LOGGER.info("Send discover via MQTT broker topic: %s", topic)
client.publish(topic, None, retain=False)
mqtt_client = prepare(
config, [topic], on_message, on_connect, username, password, client_id
)
mqtt_client.loop_start()
while timeout > 0:
if dev_ip is not None:
break
timeout -= 0.250
time.sleep(0.250)
mqtt_client.loop_stop()
if dev_ip is None:
raise EsphomeError("Failed to find IP via MQTT")
_LOGGER.info("Found IP: %s", dev_ip)
return dev_ip
def show_logs(config, topic=None, username=None, password=None, client_id=None):
@ -118,7 +226,7 @@ def show_logs(config, topic=None, username=None, password=None, client_id=None):
message = time_ + payload
safe_print(message)
return initialize(config, [topic], on_message, username, password, client_id)
return initialize(config, [topic], on_message, None, username, password, client_id)
def clear_topic(config, topic, username=None, password=None, client_id=None):
@ -142,7 +250,7 @@ def clear_topic(config, topic, username=None, password=None, client_id=None):
return
client.publish(msg.topic, None, retain=True)
return initialize(config, [topic], on_message, username, password, client_id)
return initialize(config, [topic], on_message, None, username, password, client_id)
# From marvinroger/async-mqtt-client -> scripts/get-fingerprint/get-fingerprint.py

View file

@ -10,6 +10,12 @@ from esphome import const
from esphome.core import CORE
from esphome.helpers import write_file_if_changed
from esphome.const import (
CONF_MDNS,
CONF_DISABLED,
)
from esphome.types import CoreType
_LOGGER = logging.getLogger(__name__)
@ -46,6 +52,7 @@ class StorageJSON:
build_path,
firmware_bin_path,
loaded_integrations,
no_mdns,
):
# Version of the storage JSON schema
assert storage_version is None or isinstance(storage_version, int)
@ -75,6 +82,8 @@ class StorageJSON:
# A list of strings of names of loaded integrations
self.loaded_integrations: list[str] = loaded_integrations
self.loaded_integrations.sort()
# Is mDNS disabled
self.no_mdns = no_mdns
def as_dict(self):
return {
@ -90,6 +99,7 @@ class StorageJSON:
"build_path": self.build_path,
"firmware_bin_path": self.firmware_bin_path,
"loaded_integrations": self.loaded_integrations,
"no_mdns": self.no_mdns,
}
def to_json(self):
@ -120,6 +130,11 @@ class StorageJSON:
build_path=esph.build_path,
firmware_bin_path=esph.firmware_bin,
loaded_integrations=list(esph.loaded_integrations),
no_mdns=(
CONF_MDNS in esph.config
and CONF_DISABLED in esph.config[CONF_MDNS]
and esph.config[CONF_MDNS][CONF_DISABLED] is True
),
)
@staticmethod
@ -139,6 +154,7 @@ class StorageJSON:
build_path=None,
firmware_bin_path=None,
loaded_integrations=[],
no_mdns=False,
)
@staticmethod
@ -159,6 +175,7 @@ class StorageJSON:
build_path = storage.get("build_path")
firmware_bin_path = storage.get("firmware_bin_path")
loaded_integrations = storage.get("loaded_integrations", [])
no_mdns = storage.get("no_mdns", False)
return StorageJSON(
storage_version,
name,
@ -172,6 +189,7 @@ class StorageJSON:
build_path,
firmware_bin_path,
loaded_integrations,
no_mdns,
)
@staticmethod