diff --git a/esphome/components/improv/improv.cpp b/esphome/components/improv/improv.cpp index 94068bc626..759962b51a 100644 --- a/esphome/components/improv/improv.cpp +++ b/esphome/components/improv/improv.cpp @@ -2,30 +2,32 @@ namespace improv { -ImprovCommand parse_improv_data(const std::vector &data) { - return parse_improv_data(data.data(), data.size()); +ImprovCommand parse_improv_data(const std::vector &data, bool check_checksum) { + return parse_improv_data(data.data(), data.size(), check_checksum); } -ImprovCommand parse_improv_data(const uint8_t *data, size_t length) { +ImprovCommand parse_improv_data(const uint8_t *data, size_t length, bool check_checksum) { ImprovCommand improv_command; Command command = (Command) data[0]; uint8_t data_length = data[1]; - if (data_length != length - 3) { + if (data_length != length - 2 - check_checksum) { improv_command.command = UNKNOWN; return improv_command; } - uint8_t checksum = data[length - 1]; + if (check_checksum) { + uint8_t checksum = data[length - 1]; - uint32_t calculated_checksum = 0; - for (uint8_t i = 0; i < length - 1; i++) { - calculated_checksum += data[i]; - } + uint32_t calculated_checksum = 0; + for (uint8_t i = 0; i < length - 1; i++) { + calculated_checksum += data[i]; + } - if ((uint8_t) calculated_checksum != checksum) { - improv_command.command = BAD_CHECKSUM; - return improv_command; + if ((uint8_t) calculated_checksum != checksum) { + improv_command.command = BAD_CHECKSUM; + return improv_command; + } } if (command == WIFI_SETTINGS) { @@ -46,7 +48,7 @@ ImprovCommand parse_improv_data(const uint8_t *data, size_t length) { return improv_command; } -std::vector build_rpc_response(Command command, const std::vector &datum) { +std::vector build_rpc_response(Command command, const std::vector &datum, bool add_checksum) { std::vector out; uint32_t length = 0; out.push_back(command); @@ -58,17 +60,19 @@ std::vector build_rpc_response(Command command, const std::vector build_rpc_response(Command command, const std::vector &datum) { +#ifdef ARDUINO +std::vector build_rpc_response(Command command, const std::vector &datum, bool add_checksum) { std::vector out; uint32_t length = 0; out.push_back(command); @@ -80,14 +84,16 @@ std::vector build_rpc_response(Command command, const std::vector &data); -ImprovCommand parse_improv_data(const uint8_t *data, size_t length); +ImprovCommand parse_improv_data(const std::vector &data, bool check_checksum = true); +ImprovCommand parse_improv_data(const uint8_t *data, size_t length, bool check_checksum = true); -std::vector build_rpc_response(Command command, const std::vector &datum); +std::vector build_rpc_response(Command command, const std::vector &datum, + bool add_checksum = true); #ifdef ARDUINO -std::vector build_rpc_response(Command command, const std::vector &datum); +std::vector build_rpc_response(Command command, const std::vector &datum, bool add_checksum = true); #endif // ARDUINO } // namespace improv diff --git a/esphome/components/improv_serial/improv_serial_component.cpp b/esphome/components/improv_serial/improv_serial_component.cpp index abbb76ab11..a9a7467125 100644 --- a/esphome/components/improv_serial/improv_serial_component.cpp +++ b/esphome/components/improv_serial/improv_serial_component.cpp @@ -98,13 +98,13 @@ std::vector ImprovSerialComponent::build_rpc_settings_response_(improv: std::string webserver_url = "http://" + ip.str() + ":" + to_string(WEBSERVER_PORT); urls.push_back(webserver_url); #endif - std::vector data = improv::build_rpc_response(command, urls); + std::vector data = improv::build_rpc_response(command, urls, false); return data; } std::vector ImprovSerialComponent::build_version_info_() { std::vector infos = {"ESPHome", ESPHOME_VERSION, ESPHOME_VARIANT, App.get_name()}; - std::vector data = improv::build_rpc_response(improv::GET_DEVICE_INFO, infos); + std::vector data = improv::build_rpc_response(improv::GET_DEVICE_INFO, infos, false); return data; }; @@ -140,22 +140,33 @@ bool ImprovSerialComponent::parse_improv_serial_byte_(uint8_t byte) { if (at < 8 + data_len) return true; - if (at == 8 + data_len) { + if (at == 8 + data_len) + return true; + + if (at == 8 + data_len + 1) { + uint8_t checksum = 0x00; + for (uint8_t i = 0; i < at; i++) + checksum += raw[i]; + + if (checksum != byte) { + ESP_LOGW(TAG, "Error decoding Improv payload"); + this->set_error_(improv::ERROR_INVALID_RPC); + return false; + } + if (type == TYPE_RPC) { this->set_error_(improv::ERROR_NONE); - auto command = improv::parse_improv_data(&raw[9], data_len); + auto command = improv::parse_improv_data(&raw[9], data_len, false); return this->parse_improv_payload_(command); } } - return true; + + // If we got here then the command coming is is improv, but not an RPC command + return false; } bool ImprovSerialComponent::parse_improv_payload_(improv::ImprovCommand &command) { switch (command.command) { - case improv::BAD_CHECKSUM: - ESP_LOGW(TAG, "Error decoding Improv payload"); - this->set_error_(improv::ERROR_INVALID_RPC); - return false; case improv::WIFI_SETTINGS: { wifi::WiFiAP sta{}; sta.set_ssid(command.ssid); @@ -232,6 +243,12 @@ void ImprovSerialComponent::send_response_(std::vector &response) { data[7] = TYPE_RPC_RESPONSE; data[8] = response.size(); data.insert(data.end(), response.begin(), response.end()); + + uint8_t checksum = 0x00; + for (uint8_t d : data) + checksum += d; + data.push_back(checksum); + this->write_data_(data); } diff --git a/esphome/const.py b/esphome/const.py index e2bfc81208..48b1e8aa96 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -1,6 +1,6 @@ """Constants used by esphome.""" -__version__ = "2021.11.0b6" +__version__ = "2021.11.0b7" ALLOWED_NAME_CHARS = "abcdefghijklmnopqrstuvwxyz0123456789-_"