Implement Improv via Serial component (#2423)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Jesse Hills 2021-11-11 08:55:45 +13:00 committed by GitHub
parent 0bdb48bcac
commit 5ff7c8418c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 391 additions and 14 deletions

View file

@ -73,6 +73,7 @@ esphome/components/homeassistant/* @OttoWinter
esphome/components/hrxl_maxsonar_wr/* @netmikey
esphome/components/i2c/* @esphome/core
esphome/components/improv/* @jesserockz
esphome/components/improv_serial/* @esphome/core
esphome/components/inkbird_ibsth1_mini/* @fkirill
esphome/components/inkplate6/* @jesserockz
esphome/components/integration/* @OttoWinter

View file

@ -67,6 +67,7 @@ void CaptivePortal::handle_wifisave(AsyncWebServerRequest *request) {
ESP_LOGI(TAG, " SSID='%s'", ssid.c_str());
ESP_LOGI(TAG, " Password=" LOG_SECRET("'%s'"), psk.c_str());
wifi::global_wifi_component->save_wifi_sta(ssid, psk);
wifi::global_wifi_component->start_scanning();
request->redirect("/?save=true");
}

View file

@ -28,6 +28,7 @@ from .const import ( # noqa
KEY_SDKCONFIG_OPTIONS,
KEY_VARIANT,
VARIANT_ESP32C3,
VARIANT_FRIENDLY,
VARIANTS,
)
from .boards import BOARD_TO_VARIANT
@ -287,6 +288,7 @@ async def to_code(config):
cg.add_build_flag("-DUSE_ESP32")
cg.add_define("ESPHOME_BOARD", config[CONF_BOARD])
cg.add_build_flag(f"-DUSE_ESP32_VARIANT_{config[CONF_VARIANT]}")
cg.add_define("ESPHOME_VARIANT", VARIANT_FRIENDLY[config[CONF_VARIANT]])
cg.add_platformio_option("lib_ldf_mode", "off")

View file

@ -18,4 +18,12 @@ VARIANTS = [
VARIANT_ESP32H2,
]
VARIANT_FRIENDLY = {
VARIANT_ESP32: "ESP32",
VARIANT_ESP32S2: "ESP32-S2",
VARIANT_ESP32S3: "ESP32-S3",
VARIANT_ESP32C3: "ESP32-C3",
VARIANT_ESP32H2: "ESP32-H2",
}
esp32_ns = cg.esphome_ns.namespace("esp32")

View file

@ -156,6 +156,7 @@ async def to_code(config):
cg.add_platformio_option("board", config[CONF_BOARD])
cg.add_build_flag("-DUSE_ESP8266")
cg.add_define("ESPHOME_BOARD", config[CONF_BOARD])
cg.add_define("ESPHOME_VARIANT", "ESP8266")
conf = config[CONF_FRAMEWORK]
cg.add_platformio_option("framework", "arduino")

View file

@ -7,11 +7,13 @@ ImprovCommand parse_improv_data(const std::vector<uint8_t> &data) {
}
ImprovCommand parse_improv_data(const uint8_t *data, size_t length) {
ImprovCommand improv_command;
Command command = (Command) data[0];
uint8_t data_length = data[1];
if (data_length != length - 3) {
return {.command = UNKNOWN};
improv_command.command = UNKNOWN;
return improv_command;
}
uint8_t checksum = data[length - 1];
@ -22,7 +24,8 @@ ImprovCommand parse_improv_data(const uint8_t *data, size_t length) {
}
if ((uint8_t) calculated_checksum != checksum) {
return {.command = BAD_CHECKSUM};
improv_command.command = BAD_CHECKSUM;
return improv_command;
}
if (command == WIFI_SETTINGS) {
@ -39,9 +42,8 @@ ImprovCommand parse_improv_data(const uint8_t *data, size_t length) {
return {.command = command, .ssid = ssid, .password = password};
}
return {
.command = command,
};
improv_command.command = command;
return improv_command;
}
std::vector<uint8_t> build_rpc_response(Command command, const std::vector<std::string> &datum) {

View file

@ -1,8 +1,8 @@
#pragma once
#ifdef USE_ARDUINO
#ifdef ARDUINO
#include "WString.h"
#endif // USE_ARDUINO
#endif // ARDUINO
#include <cstdint>
#include <string>
@ -38,6 +38,8 @@ enum Command : uint8_t {
UNKNOWN = 0x00,
WIFI_SETTINGS = 0x01,
IDENTIFY = 0x02,
GET_CURRENT_STATE = 0x02,
GET_DEVICE_INFO = 0x03,
BAD_CHECKSUM = 0xFF,
};
@ -53,8 +55,8 @@ ImprovCommand parse_improv_data(const std::vector<uint8_t> &data);
ImprovCommand parse_improv_data(const uint8_t *data, size_t length);
std::vector<uint8_t> build_rpc_response(Command command, const std::vector<std::string> &datum);
#ifdef USE_ARDUINO
#ifdef ARDUINO
std::vector<uint8_t> build_rpc_response(Command command, const std::vector<String> &datum);
#endif // USE_ARDUINO
#endif // ARDUINO
} // namespace improv

View file

@ -0,0 +1,33 @@
from esphome.const import CONF_BAUD_RATE, CONF_ID, CONF_LOGGER
import esphome.codegen as cg
import esphome.config_validation as cv
import esphome.final_validate as fv
CODEOWNERS = ["@esphome/core"]
DEPENDENCIES = ["logger", "wifi"]
AUTO_LOAD = ["improv"]
improv_serial_ns = cg.esphome_ns.namespace("improv_serial")
ImprovSerialComponent = improv_serial_ns.class_("ImprovSerialComponent", cg.Component)
CONFIG_SCHEMA = cv.Schema(
{
cv.GenerateID(): cv.declare_id(ImprovSerialComponent),
}
).extend(cv.COMPONENT_SCHEMA)
def validate_logger_baud_rate(config):
logger_conf = fv.full_config.get()[CONF_LOGGER]
if logger_conf[CONF_BAUD_RATE] == 0:
raise cv.Invalid("improv_serial requires the logger baud_rate to be not 0")
return config
FINAL_VALIDATE_SCHEMA = validate_logger_baud_rate
async def to_code(config):
var = cg.new_Pvariable(config[CONF_ID])
await cg.register_component(var, config)

View file

@ -0,0 +1,250 @@
#include "improv_serial_component.h"
#include "esphome/core/application.h"
#include "esphome/core/defines.h"
#include "esphome/core/hal.h"
#include "esphome/core/log.h"
#include "esphome/core/version.h"
#include "esphome/components/logger/logger.h"
namespace esphome {
namespace improv_serial {
static const char *const TAG = "improv_serial";
void ImprovSerialComponent::setup() {
global_improv_serial_component = this;
#ifdef USE_ARDUINO
this->hw_serial_ = logger::global_logger->get_hw_serial();
#endif
#ifdef USE_ESP_IDF
this->uart_num_ = logger::global_logger->get_uart_num();
#endif
if (wifi::global_wifi_component->has_sta()) {
this->state_ = improv::STATE_PROVISIONED;
}
}
void ImprovSerialComponent::dump_config() { ESP_LOGCONFIG(TAG, "Improv Serial:"); }
int ImprovSerialComponent::available_() {
#ifdef USE_ARDUINO
return this->hw_serial_->available();
#endif
#ifdef USE_ESP_IDF
size_t available;
uart_get_buffered_data_len(this->uart_num_, &available);
return available;
#endif
}
uint8_t ImprovSerialComponent::read_byte_() {
uint8_t data;
#ifdef USE_ARDUINO
this->hw_serial_->readBytes(&data, 1);
#endif
#ifdef USE_ESP_IDF
uart_read_bytes(this->uart_num_, &data, 1, 20 / portTICK_RATE_MS);
#endif
return data;
}
void ImprovSerialComponent::write_data_(std::vector<uint8_t> &data) {
data.push_back('\n');
#ifdef USE_ARDUINO
this->hw_serial_->write(data.data(), data.size());
#endif
#ifdef USE_ESP_IDF
uart_write_bytes(this->uart_num_, data.data(), data.size());
#endif
}
void ImprovSerialComponent::loop() {
const uint32_t now = millis();
if (now - this->last_read_byte_ > 50) {
this->rx_buffer_.clear();
this->last_read_byte_ = now;
}
while (this->available_()) {
uint8_t byte = this->read_byte_();
if (this->parse_improv_serial_byte_(byte)) {
this->last_read_byte_ = now;
} else {
this->rx_buffer_.clear();
}
}
if (this->state_ == improv::STATE_PROVISIONING) {
if (wifi::global_wifi_component->is_connected()) {
wifi::global_wifi_component->save_wifi_sta(this->connecting_sta_.get_ssid(),
this->connecting_sta_.get_password());
this->connecting_sta_ = {};
this->cancel_timeout("wifi-connect-timeout");
this->set_state_(improv::STATE_PROVISIONED);
std::vector<uint8_t> url = this->build_rpc_settings_response_(improv::WIFI_SETTINGS);
this->send_response_(url);
}
}
}
std::vector<uint8_t> ImprovSerialComponent::build_rpc_settings_response_(improv::Command command) {
std::string url = "https://my.home-assistant.io/redirect/config_flow_start?domain=esphome";
std::vector<std::string> urls = {url};
#ifdef USE_WEBSERVER
auto ip = wifi::global_wifi_component->wifi_sta_ip();
std::string webserver_url = "http://" + ip.str() + ":" + to_string(WEBSERVER_PORT);
urls.push_back(webserver_url);
#endif
std::vector<uint8_t> data = improv::build_rpc_response(command, urls);
return data;
}
std::vector<uint8_t> ImprovSerialComponent::build_version_info_() {
std::vector<std::string> infos = {"ESPHome", ESPHOME_VERSION, ESPHOME_VARIANT, App.get_name()};
std::vector<uint8_t> data = improv::build_rpc_response(improv::GET_DEVICE_INFO, infos);
return data;
};
bool ImprovSerialComponent::parse_improv_serial_byte_(uint8_t byte) {
size_t at = this->rx_buffer_.size();
this->rx_buffer_.push_back(byte);
ESP_LOGD(TAG, "Improv Serial byte: 0x%02X", byte);
const uint8_t *raw = &this->rx_buffer_[0];
if (at == 0)
return byte == 'I';
if (at == 1)
return byte == 'M';
if (at == 2)
return byte == 'P';
if (at == 3)
return byte == 'R';
if (at == 4)
return byte == 'O';
if (at == 5)
return byte == 'V';
if (at == 6)
return byte == IMPROV_SERIAL_VERSION;
if (at == 7)
return true;
uint8_t type = raw[7];
if (at == 8)
return true;
uint8_t data_len = raw[8];
if (at < 8 + data_len)
return true;
if (at == 8 + data_len) {
if (type == TYPE_RPC) {
this->set_error_(improv::ERROR_NONE);
auto command = improv::parse_improv_data(&raw[9], data_len);
return this->parse_improv_payload_(command);
}
}
return true;
}
bool ImprovSerialComponent::parse_improv_payload_(improv::ImprovCommand &command) {
switch (command.command) {
case improv::BAD_CHECKSUM:
ESP_LOGW(TAG, "Error decoding Improv payload");
this->set_error_(improv::ERROR_INVALID_RPC);
return false;
case improv::WIFI_SETTINGS: {
wifi::WiFiAP sta{};
sta.set_ssid(command.ssid);
sta.set_password(command.password);
this->connecting_sta_ = sta;
wifi::global_wifi_component->set_sta(sta);
wifi::global_wifi_component->start_scanning();
this->set_state_(improv::STATE_PROVISIONING);
ESP_LOGD(TAG, "Received Improv wifi settings ssid=%s, password=" LOG_SECRET("%s"), command.ssid.c_str(),
command.password.c_str());
auto f = std::bind(&ImprovSerialComponent::on_wifi_connect_timeout_, this);
this->set_timeout("wifi-connect-timeout", 30000, f);
return true;
}
case improv::GET_CURRENT_STATE:
this->set_state_(this->state_);
if (this->state_ == improv::STATE_PROVISIONED) {
std::vector<uint8_t> url = this->build_rpc_settings_response_(improv::GET_CURRENT_STATE);
this->send_response_(url);
}
return true;
case improv::GET_DEVICE_INFO: {
std::vector<uint8_t> info = this->build_version_info_();
this->send_response_(info);
return true;
}
default: {
ESP_LOGW(TAG, "Unknown Improv payload");
this->set_error_(improv::ERROR_UNKNOWN_RPC);
return false;
}
}
}
void ImprovSerialComponent::set_state_(improv::State state) {
this->state_ = state;
std::vector<uint8_t> data = {'I', 'M', 'P', 'R', 'O', 'V'};
data.resize(11);
data[6] = IMPROV_SERIAL_VERSION;
data[7] = TYPE_CURRENT_STATE;
data[8] = 1;
data[9] = state;
uint8_t checksum = 0x00;
for (uint8_t d : data)
checksum += d;
data[10] = checksum;
this->write_data_(data);
}
void ImprovSerialComponent::set_error_(improv::Error error) {
std::vector<uint8_t> data = {'I', 'M', 'P', 'R', 'O', 'V'};
data.resize(11);
data[6] = IMPROV_SERIAL_VERSION;
data[7] = TYPE_ERROR_STATE;
data[8] = 1;
data[9] = error;
uint8_t checksum = 0x00;
for (uint8_t d : data)
checksum += d;
data[10] = checksum;
this->write_data_(data);
}
void ImprovSerialComponent::send_response_(std::vector<uint8_t> &response) {
std::vector<uint8_t> data = {'I', 'M', 'P', 'R', 'O', 'V'};
data.resize(9);
data[6] = IMPROV_SERIAL_VERSION;
data[7] = TYPE_RPC_RESPONSE;
data[8] = response.size();
data.insert(data.end(), response.begin(), response.end());
this->write_data_(data);
}
void ImprovSerialComponent::on_wifi_connect_timeout_() {
this->set_error_(improv::ERROR_UNABLE_TO_CONNECT);
this->set_state_(improv::STATE_AUTHORIZED);
ESP_LOGW(TAG, "Timed out trying to connect to given WiFi network");
wifi::global_wifi_component->clear_sta();
}
ImprovSerialComponent *global_improv_serial_component = // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
} // namespace improv_serial
} // namespace esphome

View file

@ -0,0 +1,69 @@
#pragma once
#include "esphome/components/improv/improv.h"
#include "esphome/components/wifi/wifi_component.h"
#include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "esphome/core/helpers.h"
#ifdef USE_ARDUINO
#include <HardwareSerial.h>
#endif
#ifdef USE_ESP_IDF
#include <driver/uart.h>
#endif
namespace esphome {
namespace improv_serial {
enum ImprovSerialType : uint8_t {
TYPE_CURRENT_STATE = 0x01,
TYPE_ERROR_STATE = 0x02,
TYPE_RPC = 0x03,
TYPE_RPC_RESPONSE = 0x04
};
static const uint8_t IMPROV_SERIAL_VERSION = 1;
class ImprovSerialComponent : public Component {
public:
void setup() override;
void loop() override;
void dump_config() override;
float get_setup_priority() const override { return setup_priority::HARDWARE; }
protected:
bool parse_improv_serial_byte_(uint8_t byte);
bool parse_improv_payload_(improv::ImprovCommand &command);
void set_state_(improv::State state);
void set_error_(improv::Error error);
void send_response_(std::vector<uint8_t> &response);
void on_wifi_connect_timeout_();
std::vector<uint8_t> build_rpc_settings_response_(improv::Command command);
std::vector<uint8_t> build_version_info_();
int available_();
uint8_t read_byte_();
void write_data_(std::vector<uint8_t> &data);
#ifdef USE_ARDUINO
HardwareSerial *hw_serial_{nullptr};
#endif
#ifdef USE_ESP_IDF
uart_port_t uart_num_;
#endif
std::vector<uint8_t> rx_buffer_;
uint32_t last_read_byte_{0};
wifi::WiFiAP connecting_sta_;
improv::State state_{improv::STATE_AUTHORIZED};
};
extern ImprovSerialComponent
*global_improv_serial_component; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
} // namespace improv_serial
} // namespace esphome

View file

@ -221,7 +221,7 @@ UARTSelection Logger::get_uart() const { return this->uart_; }
void Logger::add_on_log_callback(std::function<void(int, const char *, const char *)> &&callback) {
this->log_callback_.add(std::move(callback));
}
float Logger::get_setup_priority() const { return setup_priority::HARDWARE - 1.0f; }
float Logger::get_setup_priority() const { return setup_priority::BUS + 500.0f; }
const char *const LOG_LEVELS[] = {"NONE", "ERROR", "WARN", "INFO", "CONFIG", "DEBUG", "VERBOSE", "VERY_VERBOSE"};
#ifdef USE_ESP32
const char *const UART_SELECTIONS[] = {"UART0", "UART1", "UART2"};

View file

@ -54,6 +54,8 @@ async def to_code(config):
var = cg.new_Pvariable(config[CONF_ID], paren)
await cg.register_component(var, config)
cg.add_define("USE_WEBSERVER")
cg.add(paren.set_port(config[CONF_PORT]))
cg.add_define("WEBSERVER_PORT", config[CONF_PORT])
cg.add_define("USE_WEBSERVER")

View file

@ -140,7 +140,8 @@ def final_validate(config):
has_sta = bool(config.get(CONF_NETWORKS, True))
has_ap = CONF_AP in config
has_improv = "esp32_improv" in fv.full_config.get()
if (not has_sta) and (not has_ap) and (not has_improv):
has_improv_serial = "improv_serial" in fv.full_config.get()
if not (has_sta or has_ap or has_improv or has_improv_serial):
raise cv.Invalid(
"Please specify at least an SSID or an Access Point to create."
)

View file

@ -239,8 +239,6 @@ void WiFiComponent::save_wifi_sta(const std::string &ssid, const std::string &pa
sta.set_ssid(ssid);
sta.set_password(password);
this->set_sta(sta);
this->start_scanning();
}
void WiFiComponent::start_connecting(const WiFiAP &ap, bool two) {

View file

@ -9,6 +9,7 @@
#define ESPHOME_BOARD "dummy_board"
#define ESPHOME_PROJECT_NAME "dummy project"
#define ESPHOME_PROJECT_VERSION "v2"
#define ESPHOME_VARIANT "ESP32"
// Feature flags
#define USE_API

View file

@ -263,7 +263,11 @@ def highlight(s):
@lint_re_check(
r"^#define\s+([a-zA-Z0-9_]+)\s+([0-9bx]+)" + CPP_RE_EOL,
include=cpp_include,
exclude=["esphome/core/log.h", "esphome/components/socket/headers.h"],
exclude=[
"esphome/core/log.h",
"esphome/components/socket/headers.h",
"esphome/core/defines.h",
],
)
def lint_no_defines(fname, match):
s = highlight(

View file

@ -268,6 +268,8 @@ logger:
level: DEBUG
esp8266_store_log_strings_in_flash: true
improv_serial:
deep_sleep:
run_duration: 20s
sleep_duration: 50s