From c5a45645a6e75f90b9a50b950130afa405c8da57 Mon Sep 17 00:00:00 2001 From: Markus <974709+Links2004@users.noreply.github.com> Date: Wed, 17 May 2023 06:29:56 +0200 Subject: [PATCH] allow to use MQTT for discovery of IPs if mDNS is no option (#3887) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- esphome/__main__.py | 62 ++++++++++-- esphome/components/mdns/__init__.py | 2 +- esphome/components/mqtt/mqtt_client.cpp | 63 ++++++++++++ esphome/components/mqtt/mqtt_client.h | 2 + esphome/const.py | 2 + esphome/dashboard/dashboard.py | 89 +++++++++++++++- esphome/helpers.py | 8 ++ esphome/mqtt.py | 128 ++++++++++++++++++++++-- esphome/storage_json.py | 18 ++++ 9 files changed, 353 insertions(+), 21 deletions(-) diff --git a/esphome/__main__.py b/esphome/__main__.py index 78320a05f0..f82a48e33b 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -18,6 +18,9 @@ from esphome.const import ( CONF_LOGGER, CONF_NAME, CONF_OTA, + CONF_MQTT, + CONF_MDNS, + CONF_DISABLED, CONF_PASSWORD, CONF_PORT, CONF_ESPHOME, @@ -42,7 +45,7 @@ from esphome.log import color, setup_log, Fore _LOGGER = logging.getLogger(__name__) -def choose_prompt(options): +def choose_prompt(options, purpose: str = None): if not options: raise EsphomeError( "Found no valid options for upload/logging, please make sure relevant " @@ -53,7 +56,9 @@ def choose_prompt(options): if len(options) == 1: return options[0][1] - safe_print("Found multiple options, please choose one:") + safe_print( + f'Found multiple options{f" for {purpose}" if purpose else ""}, please choose one:' + ) for i, (desc, _) in enumerate(options): safe_print(f" [{i+1}] {desc}") @@ -72,7 +77,9 @@ def choose_prompt(options): return options[opt - 1][1] -def choose_upload_log_host(default, check_default, show_ota, show_mqtt, show_api): +def choose_upload_log_host( + default, check_default, show_ota, show_mqtt, show_api, purpose: str = None +): options = [] for port in get_serial_ports(): options.append((f"{port.path} ({port.description})", port.path)) @@ -80,7 +87,7 @@ def choose_upload_log_host(default, check_default, show_ota, show_mqtt, show_api options.append((f"Over The Air ({CORE.address})", CORE.address)) if default == "OTA": return CORE.address - if show_mqtt and "mqtt" in CORE.config: + if show_mqtt and CONF_MQTT in CORE.config: options.append((f"MQTT ({CORE.config['mqtt'][CONF_BROKER]})", "MQTT")) if default == "OTA": return "MQTT" @@ -88,7 +95,7 @@ def choose_upload_log_host(default, check_default, show_ota, show_mqtt, show_api return default if check_default is not None and check_default in [opt[1] for opt in options]: return check_default - return choose_prompt(options) + return choose_prompt(options, purpose=purpose) def get_port_type(port): @@ -288,19 +295,30 @@ def upload_program(config, args, host): return 1 # Unknown target platform - from esphome import espota2 - if CONF_OTA not in config: raise EsphomeError( "Cannot upload Over the Air as the config does not include the ota: " "component" ) + from esphome import espota2 + ota_conf = config[CONF_OTA] remote_port = ota_conf[CONF_PORT] password = ota_conf.get(CONF_PASSWORD, "") + + if ( + get_port_type(host) == "MQTT" or config[CONF_MDNS][CONF_DISABLED] + ) and CONF_MQTT in config: + from esphome import mqtt + + host = mqtt.get_esphome_device_ip( + config, args.username, args.password, args.client_id + ) + if getattr(args, "file", None) is not None: return espota2.run_ota(host, remote_port, password, args.file) + return espota2.run_ota(host, remote_port, password, CORE.firmware_bin) @@ -310,6 +328,13 @@ def show_logs(config, args, port): if get_port_type(port) == "SERIAL": return run_miniterm(config, port) if get_port_type(port) == "NETWORK" and "api" in config: + if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config: + from esphome import mqtt + + port = mqtt.get_esphome_device_ip( + config, args.username, args.password, args.client_id + ) + from esphome.components.api.client import run_logs return run_logs(config, port) @@ -374,6 +399,7 @@ def command_upload(args, config): show_ota=True, show_mqtt=False, show_api=False, + purpose="uploading", ) exit_code = upload_program(config, args, port) if exit_code != 0: @@ -382,6 +408,15 @@ def command_upload(args, config): return 0 +def command_discover(args, config): + if "mqtt" in config: + from esphome import mqtt + + return mqtt.show_discover(config, args.username, args.password, args.client_id) + + raise EsphomeError("No discover method configured (mqtt)") + + def command_logs(args, config): port = choose_upload_log_host( default=args.device, @@ -389,6 +424,7 @@ def command_logs(args, config): show_ota=False, show_mqtt=True, show_api=True, + purpose="logging", ) return show_logs(config, args, port) @@ -407,6 +443,7 @@ def command_run(args, config): show_ota=True, show_mqtt=False, show_api=True, + purpose="uploading", ) exit_code = upload_program(config, args, port) if exit_code != 0: @@ -420,6 +457,7 @@ def command_run(args, config): show_ota=False, show_mqtt=True, show_api=True, + purpose="logging", ) return show_logs(config, args, port) @@ -623,6 +661,7 @@ POST_CONFIG_ACTIONS = { "clean": command_clean, "idedata": command_idedata, "rename": command_rename, + "discover": command_discover, } @@ -711,6 +750,15 @@ def parse_args(argv): help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.", ) + parser_discover = subparsers.add_parser( + "discover", + help="Validate the configuration and show all discovered devices.", + parents=[mqtt_options], + ) + parser_discover.add_argument( + "configuration", help="Your YAML configuration file.", nargs=1 + ) + parser_run = subparsers.add_parser( "run", help="Validate the configuration, create a binary, upload it, and start logs.", diff --git a/esphome/components/mdns/__init__.py b/esphome/components/mdns/__init__.py index 66c84da8d8..d9b36c7b09 100644 --- a/esphome/components/mdns/__init__.py +++ b/esphome/components/mdns/__init__.py @@ -6,6 +6,7 @@ from esphome.const import ( CONF_SERVICE, KEY_CORE, KEY_FRAMEWORK_VERSION, + CONF_DISABLED, ) import esphome.codegen as cg import esphome.config_validation as cv @@ -39,7 +40,6 @@ SERVICE_SCHEMA = cv.Schema( } ) -CONF_DISABLED = "disabled" CONFIG_SCHEMA = cv.All( cv.Schema( { diff --git a/esphome/components/mqtt/mqtt_client.cpp b/esphome/components/mqtt/mqtt_client.cpp index af2828ff15..cb5d306976 100644 --- a/esphome/components/mqtt/mqtt_client.cpp +++ b/esphome/components/mqtt/mqtt_client.cpp @@ -7,6 +7,7 @@ #include "esphome/core/application.h" #include "esphome/core/helpers.h" #include "esphome/core/log.h" +#include "esphome/core/version.h" #ifdef USE_LOGGER #include "esphome/components/logger/logger.h" #endif @@ -14,6 +15,13 @@ #include "lwip/err.h" #include "mqtt_component.h" +#ifdef USE_API +#include "esphome/components/api/api_server.h" +#endif +#ifdef USE_DASHBOARD_IMPORT +#include "esphome/components/dashboard_import/dashboard_import.h" +#endif + namespace esphome { namespace mqtt { @@ -58,9 +66,63 @@ void MQTTClientComponent::setup() { } #endif + this->subscribe( + "esphome/discover", [this](const std::string &topic, const std::string &payload) { this->send_device_info_(); }, + 2); + + std::string topic = "esphome/ping/"; + topic.append(App.get_name()); + this->subscribe( + topic, [this](const std::string &topic, const std::string &payload) { this->send_device_info_(); }, 2); + this->last_connected_ = millis(); this->start_dnslookup_(); } + +void MQTTClientComponent::send_device_info_() { + if (!this->is_connected()) { + return; + } + std::string topic = "esphome/discover/"; + topic.append(App.get_name()); + this->publish_json( + topic, + [](JsonObject root) { + auto ip = network::get_ip_address(); + root["ip"] = ip.str(); + root["name"] = App.get_name(); +#ifdef USE_API + root["port"] = api::global_api_server->get_port(); +#endif + root["version"] = ESPHOME_VERSION; + root["mac"] = get_mac_address(); + +#ifdef USE_ESP8266 + root["platform"] = "ESP8266"; +#endif +#ifdef USE_ESP32 + root["platform"] = "ESP32"; +#endif + + root["board"] = ESPHOME_BOARD; +#if defined(USE_WIFI) + root["network"] = "wifi"; +#elif defined(USE_ETHERNET) + root["network"] = "ethernet"; +#endif + +#ifdef ESPHOME_PROJECT_NAME + root["project_name"] = ESPHOME_PROJECT_NAME; + root["project_version"] = ESPHOME_PROJECT_VERSION; +#endif // ESPHOME_PROJECT_NAME + +#ifdef USE_DASHBOARD_IMPORT + root["package_import_url"] = dashboard_import::get_package_import_url(); +#endif + }, + 2, this->discovery_info_.retain); +} + void MQTTClientComponent::dump_config() { ESP_LOGCONFIG(TAG, "MQTT:"); ESP_LOGCONFIG(TAG, " Server Address: %s:%u (%s)", this->credentials_.address.c_str(), this->credentials_.port, @@ -226,6 +288,7 @@ void MQTTClientComponent::check_connected() { delay(100); // NOLINT this->resubscribe_subscriptions_(); + this->send_device_info_(); for (MQTTComponent *component : this->children_) component->schedule_resend_state(); diff --git a/esphome/components/mqtt/mqtt_client.h b/esphome/components/mqtt/mqtt_client.h index 188a027b91..83ed3cc645 100644 --- a/esphome/components/mqtt/mqtt_client.h +++ b/esphome/components/mqtt/mqtt_client.h @@ -251,6 +251,8 @@ class MQTTClientComponent : public Component { void set_on_disconnect(mqtt_on_disconnect_callback_t &&callback); protected: + void send_device_info_(); + /// Reconnect to the MQTT broker if not already connected. void start_connect_(); void start_dnslookup_(); diff --git a/esphome/const.py b/esphome/const.py index 2b976a7079..f784efe820 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -180,6 +180,7 @@ CONF_DIR_PIN = "dir_pin" CONF_DIRECTION = "direction" CONF_DIRECTION_OUTPUT = "direction_output" CONF_DISABLE_CRC = "disable_crc" +CONF_DISABLED = "disabled" CONF_DISABLED_BY_DEFAULT = "disabled_by_default" CONF_DISCONNECT_DELAY = "disconnect_delay" CONF_DISCOVERY = "discovery" @@ -392,6 +393,7 @@ CONF_MAX_SPEED = "max_speed" CONF_MAX_TEMPERATURE = "max_temperature" CONF_MAX_VALUE = "max_value" CONF_MAX_VOLTAGE = "max_voltage" +CONF_MDNS = "mdns" CONF_MEASUREMENT_DURATION = "measurement_duration" CONF_MEASUREMENT_SEQUENCE_NUMBER = "measurement_sequence_number" CONF_MEDIUM = "medium" diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index 1a50592a2d..8d8eb74b4b 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -1,4 +1,5 @@ import base64 +import binascii import codecs import collections import functools @@ -76,6 +77,10 @@ class DashboardSettings: def status_use_ping(self): return get_bool_env("ESPHOME_DASHBOARD_USE_PING") + @property + def status_use_mqtt(self): + return get_bool_env("ESPHOME_DASHBOARD_USE_MQTT") + @property def using_ha_addon_auth(self): if not self.on_ha_addon: @@ -583,6 +588,12 @@ class DashboardEntry: return None return self.storage.address + @property + def no_mdns(self): + if self.storage is None: + return None + return self.storage.no_mdns + @property def web_port(self): if self.storage is None: @@ -775,9 +786,12 @@ class MDNSStatusThread(threading.Thread): stat.start() while not STOP_EVENT.is_set(): entries = _list_dashboard_entries() - stat.request_query( - {entry.filename: f"{entry.name}.local." for entry in entries} - ) + hosts = {} + for entry in entries: + if entry.no_mdns is not True: + hosts[entry.filename] = f"{entry.name}.local." + + stat.request_query(hosts) IMPORT_RESULT = imports.import_state PING_REQUEST.wait() @@ -801,6 +815,9 @@ class PingStatusThread(threading.Thread): entries = _list_dashboard_entries() queue = collections.deque() for entry in entries: + if entry.no_mdns is True: + continue + if entry.address is None: PING_RESULT[entry.filename] = None continue @@ -832,10 +849,67 @@ class PingStatusThread(threading.Thread): PING_REQUEST.clear() +class MqttStatusThread(threading.Thread): + def run(self): + from esphome import mqtt + + entries = _list_dashboard_entries() + + config = mqtt.config_from_env() + topic = "esphome/discover/#" + + def on_message(client, userdata, msg): + nonlocal entries + + payload = msg.payload.decode(errors="backslashreplace") + if len(payload) > 0: + data = json.loads(payload) + if "name" not in data: + return + for entry in entries: + if entry.name == data["name"]: + PING_RESULT[entry.filename] = True + return + + def on_connect(client, userdata, flags, return_code): + client.publish("esphome/discover", None, retain=False) + + mqttid = str(binascii.hexlify(os.urandom(6)).decode()) + + client = mqtt.prepare( + config, + [topic], + on_message, + on_connect, + None, + None, + f"esphome-dashboard-{mqttid}", + ) + client.loop_start() + + while not STOP_EVENT.wait(2): + # update entries + entries = _list_dashboard_entries() + + # will be set to true on on_message + for entry in entries: + if entry.no_mdns: + PING_RESULT[entry.filename] = False + + client.publish("esphome/discover", None, retain=False) + MQTT_PING_REQUEST.wait() + MQTT_PING_REQUEST.clear() + + client.disconnect() + client.loop_stop() + + class PingRequestHandler(BaseHandler): @authenticated def get(self): PING_REQUEST.set() + if settings.status_use_mqtt: + MQTT_PING_REQUEST.set() self.set_header("content-type", "application/json") self.write(json.dumps(PING_RESULT)) @@ -910,6 +984,7 @@ PING_RESULT: dict = {} IMPORT_RESULT = {} STOP_EVENT = threading.Event() PING_REQUEST = threading.Event() +MQTT_PING_REQUEST = threading.Event() class LoginHandler(BaseHandler): @@ -1197,6 +1272,11 @@ def start_web_server(args): else: status_thread = MDNSStatusThread() status_thread.start() + + if settings.status_use_mqtt: + status_thread_mqtt = MqttStatusThread() + status_thread_mqtt.start() + try: tornado.ioloop.IOLoop.current().start() except KeyboardInterrupt: @@ -1204,5 +1284,8 @@ def start_web_server(args): STOP_EVENT.set() PING_REQUEST.set() status_thread.join() + if settings.status_use_mqtt: + status_thread_mqtt.join() + MQTT_PING_REQUEST.set() if args.socket is not None: os.remove(args.socket) diff --git a/esphome/helpers.py b/esphome/helpers.py index 884f640d7b..fd8893ad99 100644 --- a/esphome/helpers.py +++ b/esphome/helpers.py @@ -147,6 +147,14 @@ def get_bool_env(var, default=False): return bool(os.getenv(var, default)) +def get_str_env(var, default=None): + return str(os.getenv(var, default)) + + +def get_int_env(var, default=0): + return int(os.getenv(var, default)) + + def is_ha_addon(): return get_bool_env("ESPHOME_IS_HA_ADDON") diff --git a/esphome/mqtt.py b/esphome/mqtt.py index 0ddd976072..166301005d 100644 --- a/esphome/mqtt.py +++ b/esphome/mqtt.py @@ -4,6 +4,7 @@ import logging import ssl import sys import time +import json import paho.mqtt.client as mqtt @@ -24,15 +25,45 @@ from esphome.const import ( from esphome.core import CORE, EsphomeError from esphome.log import color, Fore from esphome.util import safe_print +from esphome.helpers import get_str_env, get_int_env _LOGGER = logging.getLogger(__name__) -def initialize(config, subscriptions, on_message, username, password, client_id): - def on_connect(client, userdata, flags, return_code): +def config_from_env(): + config = { + CONF_MQTT: { + CONF_USERNAME: get_str_env("ESPHOME_DASHBOARD_MQTT_USERNAME"), + CONF_PASSWORD: get_str_env("ESPHOME_DASHBOARD_MQTT_PASSWORD"), + CONF_BROKER: get_str_env("ESPHOME_DASHBOARD_MQTT_BROKER"), + CONF_PORT: get_int_env("ESPHOME_DASHBOARD_MQTT_PORT", 1883), + }, + } + return config + + +def initialize( + config, subscriptions, on_message, on_connect, username, password, client_id +): + client = prepare( + config, subscriptions, on_message, on_connect, username, password, client_id + ) + try: + client.loop_forever() + except KeyboardInterrupt: + pass + return 0 + + +def prepare( + config, subscriptions, on_message, on_connect, username, password, client_id +): + def on_connect_(client, userdata, flags, return_code): _LOGGER.info("Connected to MQTT broker!") for topic in subscriptions: client.subscribe(topic) + if on_connect is not None: + on_connect(client, userdata, flags, return_code) def on_disconnect(client, userdata, result_code): if result_code == 0: @@ -57,7 +88,7 @@ def initialize(config, subscriptions, on_message, username, password, client_id) tries += 1 client = mqtt.Client(client_id or "") - client.on_connect = on_connect + client.on_connect = on_connect_ client.on_message = on_message client.on_disconnect = on_disconnect if username is None: @@ -89,11 +120,88 @@ def initialize(config, subscriptions, on_message, username, password, client_id) except OSError as err: raise EsphomeError(f"Cannot connect to MQTT broker: {err}") from err - try: - client.loop_forever() - except KeyboardInterrupt: - pass - return 0 + return client + + +def show_discover(config, username=None, password=None, client_id=None): + topic = "esphome/discover/#" + _LOGGER.info("Starting log output from %s", topic) + + def on_message(client, userdata, msg): + time_ = datetime.now().time().strftime("[%H:%M:%S]") + payload = msg.payload.decode(errors="backslashreplace") + if len(payload) > 0: + message = time_ + " " + payload + safe_print(message) + + def on_connect(client, userdata, flags, return_code): + _LOGGER.info("Send discover via MQTT broker") + client.publish("esphome/discover", None, retain=False) + + return initialize( + config, [topic], on_message, on_connect, username, password, client_id + ) + + +def get_esphome_device_ip( + config, username=None, password=None, client_id=None, timeout=25 +): + if CONF_MQTT not in config: + raise EsphomeError( + "Cannot discover IP via MQTT as the config does not include the mqtt: " + "component" + ) + if CONF_ESPHOME not in config or CONF_NAME not in config[CONF_ESPHOME]: + raise EsphomeError( + "Cannot discover IP via MQTT as the config does not include the device name: " + "component" + ) + + dev_name = config[CONF_ESPHOME][CONF_NAME] + dev_ip = None + + topic = "esphome/discover/" + dev_name + _LOGGER.info("Starting looking for IP in topic %s", topic) + + def on_message(client, userdata, msg): + nonlocal dev_ip + time_ = datetime.now().time().strftime("[%H:%M:%S]") + payload = msg.payload.decode(errors="backslashreplace") + if len(payload) > 0: + message = time_ + " " + payload + _LOGGER.debug(message) + + data = json.loads(payload) + if "name" not in data or data["name"] != dev_name: + _LOGGER.Warn("Wrong device answer") + return + + if "ip" in data: + dev_ip = data["ip"] + client.disconnect() + + def on_connect(client, userdata, flags, return_code): + topic = "esphome/ping/" + dev_name + _LOGGER.info("Send discover via MQTT broker topic: %s", topic) + client.publish(topic, None, retain=False) + + mqtt_client = prepare( + config, [topic], on_message, on_connect, username, password, client_id + ) + + mqtt_client.loop_start() + while timeout > 0: + if dev_ip is not None: + break + timeout -= 0.250 + time.sleep(0.250) + mqtt_client.loop_stop() + + if dev_ip is None: + raise EsphomeError("Failed to find IP via MQTT") + + _LOGGER.info("Found IP: %s", dev_ip) + return dev_ip def show_logs(config, topic=None, username=None, password=None, client_id=None): @@ -118,7 +226,7 @@ def show_logs(config, topic=None, username=None, password=None, client_id=None): message = time_ + payload safe_print(message) - return initialize(config, [topic], on_message, username, password, client_id) + return initialize(config, [topic], on_message, None, username, password, client_id) def clear_topic(config, topic, username=None, password=None, client_id=None): @@ -142,7 +250,7 @@ def clear_topic(config, topic, username=None, password=None, client_id=None): return client.publish(msg.topic, None, retain=True) - return initialize(config, [topic], on_message, username, password, client_id) + return initialize(config, [topic], on_message, None, username, password, client_id) # From marvinroger/async-mqtt-client -> scripts/get-fingerprint/get-fingerprint.py diff --git a/esphome/storage_json.py b/esphome/storage_json.py index bbdfbbc8a2..acf525203d 100644 --- a/esphome/storage_json.py +++ b/esphome/storage_json.py @@ -10,6 +10,12 @@ from esphome import const from esphome.core import CORE from esphome.helpers import write_file_if_changed + +from esphome.const import ( + CONF_MDNS, + CONF_DISABLED, +) + from esphome.types import CoreType _LOGGER = logging.getLogger(__name__) @@ -46,6 +52,7 @@ class StorageJSON: build_path, firmware_bin_path, loaded_integrations, + no_mdns, ): # Version of the storage JSON schema assert storage_version is None or isinstance(storage_version, int) @@ -75,6 +82,8 @@ class StorageJSON: # A list of strings of names of loaded integrations self.loaded_integrations: list[str] = loaded_integrations self.loaded_integrations.sort() + # Is mDNS disabled + self.no_mdns = no_mdns def as_dict(self): return { @@ -90,6 +99,7 @@ class StorageJSON: "build_path": self.build_path, "firmware_bin_path": self.firmware_bin_path, "loaded_integrations": self.loaded_integrations, + "no_mdns": self.no_mdns, } def to_json(self): @@ -120,6 +130,11 @@ class StorageJSON: build_path=esph.build_path, firmware_bin_path=esph.firmware_bin, loaded_integrations=list(esph.loaded_integrations), + no_mdns=( + CONF_MDNS in esph.config + and CONF_DISABLED in esph.config[CONF_MDNS] + and esph.config[CONF_MDNS][CONF_DISABLED] is True + ), ) @staticmethod @@ -139,6 +154,7 @@ class StorageJSON: build_path=None, firmware_bin_path=None, loaded_integrations=[], + no_mdns=False, ) @staticmethod @@ -159,6 +175,7 @@ class StorageJSON: build_path = storage.get("build_path") firmware_bin_path = storage.get("firmware_bin_path") loaded_integrations = storage.get("loaded_integrations", []) + no_mdns = storage.get("no_mdns", False) return StorageJSON( storage_version, name, @@ -172,6 +189,7 @@ class StorageJSON: build_path, firmware_bin_path, loaded_integrations, + no_mdns, ) @staticmethod