From 5ff7c8418c584fc10f49caaaf9736ffffb321c00 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Thu, 11 Nov 2021 08:55:45 +1300 Subject: [PATCH] Implement Improv via Serial component (#2423) Co-authored-by: Paulus Schoutsen --- CODEOWNERS | 1 + .../captive_portal/captive_portal.cpp | 1 + esphome/components/esp32/__init__.py | 2 + esphome/components/esp32/const.py | 8 + esphome/components/esp8266/__init__.py | 1 + esphome/components/improv/improv.cpp | 12 +- esphome/components/improv/improv.h | 10 +- esphome/components/improv_serial/__init__.py | 33 +++ .../improv_serial/improv_serial_component.cpp | 250 ++++++++++++++++++ .../improv_serial/improv_serial_component.h | 69 +++++ esphome/components/logger/logger.cpp | 2 +- esphome/components/web_server/__init__.py | 2 + esphome/components/wifi/__init__.py | 3 +- esphome/components/wifi/wifi_component.cpp | 2 - esphome/core/defines.h | 1 + script/ci-custom.py | 6 +- tests/test3.yaml | 2 + 17 files changed, 391 insertions(+), 14 deletions(-) create mode 100644 esphome/components/improv_serial/__init__.py create mode 100644 esphome/components/improv_serial/improv_serial_component.cpp create mode 100644 esphome/components/improv_serial/improv_serial_component.h diff --git a/CODEOWNERS b/CODEOWNERS index 8f98fe1f7f..18b4564280 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -73,6 +73,7 @@ esphome/components/homeassistant/* @OttoWinter esphome/components/hrxl_maxsonar_wr/* @netmikey esphome/components/i2c/* @esphome/core esphome/components/improv/* @jesserockz +esphome/components/improv_serial/* @esphome/core esphome/components/inkbird_ibsth1_mini/* @fkirill esphome/components/inkplate6/* @jesserockz esphome/components/integration/* @OttoWinter diff --git a/esphome/components/captive_portal/captive_portal.cpp b/esphome/components/captive_portal/captive_portal.cpp index 9e00adae3d..ad4c32bb1f 100644 --- a/esphome/components/captive_portal/captive_portal.cpp +++ b/esphome/components/captive_portal/captive_portal.cpp @@ -67,6 +67,7 @@ void CaptivePortal::handle_wifisave(AsyncWebServerRequest *request) { ESP_LOGI(TAG, " SSID='%s'", ssid.c_str()); ESP_LOGI(TAG, " Password=" LOG_SECRET("'%s'"), psk.c_str()); wifi::global_wifi_component->save_wifi_sta(ssid, psk); + wifi::global_wifi_component->start_scanning(); request->redirect("/?save=true"); } diff --git a/esphome/components/esp32/__init__.py b/esphome/components/esp32/__init__.py index d84663b2d6..1c249476e7 100644 --- a/esphome/components/esp32/__init__.py +++ b/esphome/components/esp32/__init__.py @@ -28,6 +28,7 @@ from .const import ( # noqa KEY_SDKCONFIG_OPTIONS, KEY_VARIANT, VARIANT_ESP32C3, + VARIANT_FRIENDLY, VARIANTS, ) from .boards import BOARD_TO_VARIANT @@ -287,6 +288,7 @@ async def to_code(config): cg.add_build_flag("-DUSE_ESP32") cg.add_define("ESPHOME_BOARD", config[CONF_BOARD]) cg.add_build_flag(f"-DUSE_ESP32_VARIANT_{config[CONF_VARIANT]}") + cg.add_define("ESPHOME_VARIANT", VARIANT_FRIENDLY[config[CONF_VARIANT]]) cg.add_platformio_option("lib_ldf_mode", "off") diff --git a/esphome/components/esp32/const.py b/esphome/components/esp32/const.py index b82f03bf68..d92b449ee9 100644 --- a/esphome/components/esp32/const.py +++ b/esphome/components/esp32/const.py @@ -18,4 +18,12 @@ VARIANTS = [ VARIANT_ESP32H2, ] +VARIANT_FRIENDLY = { + VARIANT_ESP32: "ESP32", + VARIANT_ESP32S2: "ESP32-S2", + VARIANT_ESP32S3: "ESP32-S3", + VARIANT_ESP32C3: "ESP32-C3", + VARIANT_ESP32H2: "ESP32-H2", +} + esp32_ns = cg.esphome_ns.namespace("esp32") diff --git a/esphome/components/esp8266/__init__.py b/esphome/components/esp8266/__init__.py index 5b97d2d9d5..34c792499d 100644 --- a/esphome/components/esp8266/__init__.py +++ b/esphome/components/esp8266/__init__.py @@ -156,6 +156,7 @@ async def to_code(config): cg.add_platformio_option("board", config[CONF_BOARD]) cg.add_build_flag("-DUSE_ESP8266") cg.add_define("ESPHOME_BOARD", config[CONF_BOARD]) + cg.add_define("ESPHOME_VARIANT", "ESP8266") conf = config[CONF_FRAMEWORK] cg.add_platformio_option("framework", "arduino") diff --git a/esphome/components/improv/improv.cpp b/esphome/components/improv/improv.cpp index 4f6ed7702d..94068bc626 100644 --- a/esphome/components/improv/improv.cpp +++ b/esphome/components/improv/improv.cpp @@ -7,11 +7,13 @@ ImprovCommand parse_improv_data(const std::vector &data) { } ImprovCommand parse_improv_data(const uint8_t *data, size_t length) { + ImprovCommand improv_command; Command command = (Command) data[0]; uint8_t data_length = data[1]; if (data_length != length - 3) { - return {.command = UNKNOWN}; + improv_command.command = UNKNOWN; + return improv_command; } uint8_t checksum = data[length - 1]; @@ -22,7 +24,8 @@ ImprovCommand parse_improv_data(const uint8_t *data, size_t length) { } if ((uint8_t) calculated_checksum != checksum) { - return {.command = BAD_CHECKSUM}; + improv_command.command = BAD_CHECKSUM; + return improv_command; } if (command == WIFI_SETTINGS) { @@ -39,9 +42,8 @@ ImprovCommand parse_improv_data(const uint8_t *data, size_t length) { return {.command = command, .ssid = ssid, .password = password}; } - return { - .command = command, - }; + improv_command.command = command; + return improv_command; } std::vector build_rpc_response(Command command, const std::vector &datum) { diff --git a/esphome/components/improv/improv.h b/esphome/components/improv/improv.h index 0ead80e2cf..542eb82bd3 100644 --- a/esphome/components/improv/improv.h +++ b/esphome/components/improv/improv.h @@ -1,8 +1,8 @@ #pragma once -#ifdef USE_ARDUINO +#ifdef ARDUINO #include "WString.h" -#endif // USE_ARDUINO +#endif // ARDUINO #include #include @@ -38,6 +38,8 @@ enum Command : uint8_t { UNKNOWN = 0x00, WIFI_SETTINGS = 0x01, IDENTIFY = 0x02, + GET_CURRENT_STATE = 0x02, + GET_DEVICE_INFO = 0x03, BAD_CHECKSUM = 0xFF, }; @@ -53,8 +55,8 @@ ImprovCommand parse_improv_data(const std::vector &data); ImprovCommand parse_improv_data(const uint8_t *data, size_t length); std::vector build_rpc_response(Command command, const std::vector &datum); -#ifdef USE_ARDUINO +#ifdef ARDUINO std::vector build_rpc_response(Command command, const std::vector &datum); -#endif // USE_ARDUINO +#endif // ARDUINO } // namespace improv diff --git a/esphome/components/improv_serial/__init__.py b/esphome/components/improv_serial/__init__.py new file mode 100644 index 0000000000..b1cdc2d93e --- /dev/null +++ b/esphome/components/improv_serial/__init__.py @@ -0,0 +1,33 @@ +from esphome.const import CONF_BAUD_RATE, CONF_ID, CONF_LOGGER +import esphome.codegen as cg +import esphome.config_validation as cv +import esphome.final_validate as fv + +CODEOWNERS = ["@esphome/core"] +DEPENDENCIES = ["logger", "wifi"] +AUTO_LOAD = ["improv"] + +improv_serial_ns = cg.esphome_ns.namespace("improv_serial") + +ImprovSerialComponent = improv_serial_ns.class_("ImprovSerialComponent", cg.Component) + +CONFIG_SCHEMA = cv.Schema( + { + cv.GenerateID(): cv.declare_id(ImprovSerialComponent), + } +).extend(cv.COMPONENT_SCHEMA) + + +def validate_logger_baud_rate(config): + logger_conf = fv.full_config.get()[CONF_LOGGER] + if logger_conf[CONF_BAUD_RATE] == 0: + raise cv.Invalid("improv_serial requires the logger baud_rate to be not 0") + return config + + +FINAL_VALIDATE_SCHEMA = validate_logger_baud_rate + + +async def to_code(config): + var = cg.new_Pvariable(config[CONF_ID]) + await cg.register_component(var, config) diff --git a/esphome/components/improv_serial/improv_serial_component.cpp b/esphome/components/improv_serial/improv_serial_component.cpp new file mode 100644 index 0000000000..a12f1bd83b --- /dev/null +++ b/esphome/components/improv_serial/improv_serial_component.cpp @@ -0,0 +1,250 @@ +#include "improv_serial_component.h" + +#include "esphome/core/application.h" +#include "esphome/core/defines.h" +#include "esphome/core/hal.h" +#include "esphome/core/log.h" +#include "esphome/core/version.h" + +#include "esphome/components/logger/logger.h" + +namespace esphome { +namespace improv_serial { + +static const char *const TAG = "improv_serial"; + +void ImprovSerialComponent::setup() { + global_improv_serial_component = this; +#ifdef USE_ARDUINO + this->hw_serial_ = logger::global_logger->get_hw_serial(); +#endif +#ifdef USE_ESP_IDF + this->uart_num_ = logger::global_logger->get_uart_num(); +#endif + + if (wifi::global_wifi_component->has_sta()) { + this->state_ = improv::STATE_PROVISIONED; + } +} + +void ImprovSerialComponent::dump_config() { ESP_LOGCONFIG(TAG, "Improv Serial:"); } + +int ImprovSerialComponent::available_() { +#ifdef USE_ARDUINO + return this->hw_serial_->available(); +#endif +#ifdef USE_ESP_IDF + size_t available; + uart_get_buffered_data_len(this->uart_num_, &available); + return available; +#endif +} + +uint8_t ImprovSerialComponent::read_byte_() { + uint8_t data; +#ifdef USE_ARDUINO + this->hw_serial_->readBytes(&data, 1); +#endif +#ifdef USE_ESP_IDF + uart_read_bytes(this->uart_num_, &data, 1, 20 / portTICK_RATE_MS); +#endif + return data; +} + +void ImprovSerialComponent::write_data_(std::vector &data) { + data.push_back('\n'); +#ifdef USE_ARDUINO + this->hw_serial_->write(data.data(), data.size()); +#endif +#ifdef USE_ESP_IDF + uart_write_bytes(this->uart_num_, data.data(), data.size()); +#endif +} + +void ImprovSerialComponent::loop() { + const uint32_t now = millis(); + if (now - this->last_read_byte_ > 50) { + this->rx_buffer_.clear(); + this->last_read_byte_ = now; + } + + while (this->available_()) { + uint8_t byte = this->read_byte_(); + if (this->parse_improv_serial_byte_(byte)) { + this->last_read_byte_ = now; + } else { + this->rx_buffer_.clear(); + } + } + + if (this->state_ == improv::STATE_PROVISIONING) { + if (wifi::global_wifi_component->is_connected()) { + wifi::global_wifi_component->save_wifi_sta(this->connecting_sta_.get_ssid(), + this->connecting_sta_.get_password()); + this->connecting_sta_ = {}; + this->cancel_timeout("wifi-connect-timeout"); + this->set_state_(improv::STATE_PROVISIONED); + + std::vector url = this->build_rpc_settings_response_(improv::WIFI_SETTINGS); + this->send_response_(url); + } + } +} + +std::vector ImprovSerialComponent::build_rpc_settings_response_(improv::Command command) { + std::string url = "https://my.home-assistant.io/redirect/config_flow_start?domain=esphome"; + std::vector urls = {url}; +#ifdef USE_WEBSERVER + auto ip = wifi::global_wifi_component->wifi_sta_ip(); + std::string webserver_url = "http://" + ip.str() + ":" + to_string(WEBSERVER_PORT); + urls.push_back(webserver_url); +#endif + std::vector data = improv::build_rpc_response(command, urls); + return data; +} + +std::vector ImprovSerialComponent::build_version_info_() { + std::vector infos = {"ESPHome", ESPHOME_VERSION, ESPHOME_VARIANT, App.get_name()}; + std::vector data = improv::build_rpc_response(improv::GET_DEVICE_INFO, infos); + return data; +}; + +bool ImprovSerialComponent::parse_improv_serial_byte_(uint8_t byte) { + size_t at = this->rx_buffer_.size(); + this->rx_buffer_.push_back(byte); + ESP_LOGD(TAG, "Improv Serial byte: 0x%02X", byte); + const uint8_t *raw = &this->rx_buffer_[0]; + if (at == 0) + return byte == 'I'; + if (at == 1) + return byte == 'M'; + if (at == 2) + return byte == 'P'; + if (at == 3) + return byte == 'R'; + if (at == 4) + return byte == 'O'; + if (at == 5) + return byte == 'V'; + + if (at == 6) + return byte == IMPROV_SERIAL_VERSION; + + if (at == 7) + return true; + uint8_t type = raw[7]; + + if (at == 8) + return true; + uint8_t data_len = raw[8]; + + if (at < 8 + data_len) + return true; + + if (at == 8 + data_len) { + if (type == TYPE_RPC) { + this->set_error_(improv::ERROR_NONE); + auto command = improv::parse_improv_data(&raw[9], data_len); + return this->parse_improv_payload_(command); + } + } + return true; +} + +bool ImprovSerialComponent::parse_improv_payload_(improv::ImprovCommand &command) { + switch (command.command) { + case improv::BAD_CHECKSUM: + ESP_LOGW(TAG, "Error decoding Improv payload"); + this->set_error_(improv::ERROR_INVALID_RPC); + return false; + case improv::WIFI_SETTINGS: { + wifi::WiFiAP sta{}; + sta.set_ssid(command.ssid); + sta.set_password(command.password); + this->connecting_sta_ = sta; + + wifi::global_wifi_component->set_sta(sta); + wifi::global_wifi_component->start_scanning(); + this->set_state_(improv::STATE_PROVISIONING); + ESP_LOGD(TAG, "Received Improv wifi settings ssid=%s, password=" LOG_SECRET("%s"), command.ssid.c_str(), + command.password.c_str()); + + auto f = std::bind(&ImprovSerialComponent::on_wifi_connect_timeout_, this); + this->set_timeout("wifi-connect-timeout", 30000, f); + return true; + } + case improv::GET_CURRENT_STATE: + this->set_state_(this->state_); + if (this->state_ == improv::STATE_PROVISIONED) { + std::vector url = this->build_rpc_settings_response_(improv::GET_CURRENT_STATE); + this->send_response_(url); + } + return true; + case improv::GET_DEVICE_INFO: { + std::vector info = this->build_version_info_(); + this->send_response_(info); + return true; + } + default: { + ESP_LOGW(TAG, "Unknown Improv payload"); + this->set_error_(improv::ERROR_UNKNOWN_RPC); + return false; + } + } +} + +void ImprovSerialComponent::set_state_(improv::State state) { + this->state_ = state; + + std::vector data = {'I', 'M', 'P', 'R', 'O', 'V'}; + data.resize(11); + data[6] = IMPROV_SERIAL_VERSION; + data[7] = TYPE_CURRENT_STATE; + data[8] = 1; + data[9] = state; + + uint8_t checksum = 0x00; + for (uint8_t d : data) + checksum += d; + data[10] = checksum; + + this->write_data_(data); +} + +void ImprovSerialComponent::set_error_(improv::Error error) { + std::vector data = {'I', 'M', 'P', 'R', 'O', 'V'}; + data.resize(11); + data[6] = IMPROV_SERIAL_VERSION; + data[7] = TYPE_ERROR_STATE; + data[8] = 1; + data[9] = error; + + uint8_t checksum = 0x00; + for (uint8_t d : data) + checksum += d; + data[10] = checksum; + this->write_data_(data); +} + +void ImprovSerialComponent::send_response_(std::vector &response) { + std::vector data = {'I', 'M', 'P', 'R', 'O', 'V'}; + data.resize(9); + data[6] = IMPROV_SERIAL_VERSION; + data[7] = TYPE_RPC_RESPONSE; + data[8] = response.size(); + data.insert(data.end(), response.begin(), response.end()); + this->write_data_(data); +} + +void ImprovSerialComponent::on_wifi_connect_timeout_() { + this->set_error_(improv::ERROR_UNABLE_TO_CONNECT); + this->set_state_(improv::STATE_AUTHORIZED); + ESP_LOGW(TAG, "Timed out trying to connect to given WiFi network"); + wifi::global_wifi_component->clear_sta(); +} + +ImprovSerialComponent *global_improv_serial_component = // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) + nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) + +} // namespace improv_serial +} // namespace esphome diff --git a/esphome/components/improv_serial/improv_serial_component.h b/esphome/components/improv_serial/improv_serial_component.h new file mode 100644 index 0000000000..539674e2d3 --- /dev/null +++ b/esphome/components/improv_serial/improv_serial_component.h @@ -0,0 +1,69 @@ +#pragma once + +#include "esphome/components/improv/improv.h" +#include "esphome/components/wifi/wifi_component.h" +#include "esphome/core/component.h" +#include "esphome/core/defines.h" +#include "esphome/core/helpers.h" + +#ifdef USE_ARDUINO +#include +#endif +#ifdef USE_ESP_IDF +#include +#endif + +namespace esphome { +namespace improv_serial { + +enum ImprovSerialType : uint8_t { + TYPE_CURRENT_STATE = 0x01, + TYPE_ERROR_STATE = 0x02, + TYPE_RPC = 0x03, + TYPE_RPC_RESPONSE = 0x04 +}; + +static const uint8_t IMPROV_SERIAL_VERSION = 1; + +class ImprovSerialComponent : public Component { + public: + void setup() override; + void loop() override; + void dump_config() override; + + float get_setup_priority() const override { return setup_priority::HARDWARE; } + + protected: + bool parse_improv_serial_byte_(uint8_t byte); + bool parse_improv_payload_(improv::ImprovCommand &command); + + void set_state_(improv::State state); + void set_error_(improv::Error error); + void send_response_(std::vector &response); + void on_wifi_connect_timeout_(); + + std::vector build_rpc_settings_response_(improv::Command command); + std::vector build_version_info_(); + + int available_(); + uint8_t read_byte_(); + void write_data_(std::vector &data); + +#ifdef USE_ARDUINO + HardwareSerial *hw_serial_{nullptr}; +#endif +#ifdef USE_ESP_IDF + uart_port_t uart_num_; +#endif + + std::vector rx_buffer_; + uint32_t last_read_byte_{0}; + wifi::WiFiAP connecting_sta_; + improv::State state_{improv::STATE_AUTHORIZED}; +}; + +extern ImprovSerialComponent + *global_improv_serial_component; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) + +} // namespace improv_serial +} // namespace esphome diff --git a/esphome/components/logger/logger.cpp b/esphome/components/logger/logger.cpp index 97ad4c2cb9..11c0733701 100644 --- a/esphome/components/logger/logger.cpp +++ b/esphome/components/logger/logger.cpp @@ -221,7 +221,7 @@ UARTSelection Logger::get_uart() const { return this->uart_; } void Logger::add_on_log_callback(std::function &&callback) { this->log_callback_.add(std::move(callback)); } -float Logger::get_setup_priority() const { return setup_priority::HARDWARE - 1.0f; } +float Logger::get_setup_priority() const { return setup_priority::BUS + 500.0f; } const char *const LOG_LEVELS[] = {"NONE", "ERROR", "WARN", "INFO", "CONFIG", "DEBUG", "VERBOSE", "VERY_VERBOSE"}; #ifdef USE_ESP32 const char *const UART_SELECTIONS[] = {"UART0", "UART1", "UART2"}; diff --git a/esphome/components/web_server/__init__.py b/esphome/components/web_server/__init__.py index 61b1fa5ad6..dc652e0312 100644 --- a/esphome/components/web_server/__init__.py +++ b/esphome/components/web_server/__init__.py @@ -54,6 +54,8 @@ async def to_code(config): var = cg.new_Pvariable(config[CONF_ID], paren) await cg.register_component(var, config) + cg.add_define("USE_WEBSERVER") + cg.add(paren.set_port(config[CONF_PORT])) cg.add_define("WEBSERVER_PORT", config[CONF_PORT]) cg.add_define("USE_WEBSERVER") diff --git a/esphome/components/wifi/__init__.py b/esphome/components/wifi/__init__.py index faf3cca280..7a9319f5e0 100644 --- a/esphome/components/wifi/__init__.py +++ b/esphome/components/wifi/__init__.py @@ -140,7 +140,8 @@ def final_validate(config): has_sta = bool(config.get(CONF_NETWORKS, True)) has_ap = CONF_AP in config has_improv = "esp32_improv" in fv.full_config.get() - if (not has_sta) and (not has_ap) and (not has_improv): + has_improv_serial = "improv_serial" in fv.full_config.get() + if not (has_sta or has_ap or has_improv or has_improv_serial): raise cv.Invalid( "Please specify at least an SSID or an Access Point to create." ) diff --git a/esphome/components/wifi/wifi_component.cpp b/esphome/components/wifi/wifi_component.cpp index 703afa99bc..36944e3633 100644 --- a/esphome/components/wifi/wifi_component.cpp +++ b/esphome/components/wifi/wifi_component.cpp @@ -239,8 +239,6 @@ void WiFiComponent::save_wifi_sta(const std::string &ssid, const std::string &pa sta.set_ssid(ssid); sta.set_password(password); this->set_sta(sta); - - this->start_scanning(); } void WiFiComponent::start_connecting(const WiFiAP &ap, bool two) { diff --git a/esphome/core/defines.h b/esphome/core/defines.h index dc07bde196..94fac73906 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -9,6 +9,7 @@ #define ESPHOME_BOARD "dummy_board" #define ESPHOME_PROJECT_NAME "dummy project" #define ESPHOME_PROJECT_VERSION "v2" +#define ESPHOME_VARIANT "ESP32" // Feature flags #define USE_API diff --git a/script/ci-custom.py b/script/ci-custom.py index 8e9ca487a6..89550afd3d 100755 --- a/script/ci-custom.py +++ b/script/ci-custom.py @@ -263,7 +263,11 @@ def highlight(s): @lint_re_check( r"^#define\s+([a-zA-Z0-9_]+)\s+([0-9bx]+)" + CPP_RE_EOL, include=cpp_include, - exclude=["esphome/core/log.h", "esphome/components/socket/headers.h"], + exclude=[ + "esphome/core/log.h", + "esphome/components/socket/headers.h", + "esphome/core/defines.h", + ], ) def lint_no_defines(fname, match): s = highlight( diff --git a/tests/test3.yaml b/tests/test3.yaml index 0b7f1ad71e..cf80c06aa8 100644 --- a/tests/test3.yaml +++ b/tests/test3.yaml @@ -268,6 +268,8 @@ logger: level: DEBUG esp8266_store_log_strings_in_flash: true +improv_serial: + deep_sleep: run_duration: 20s sleep_duration: 50s