Add OTA upload compression for ESP8266 (#2601)

This commit is contained in:
Otto Winter 2021-10-22 13:02:55 +02:00 committed by GitHub
parent c08b21b7cd
commit 0d90ef94ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 42 additions and 16 deletions

View file

@ -12,6 +12,7 @@ class OTABackend {
virtual OTAResponseTypes write(uint8_t *data, size_t len) = 0; virtual OTAResponseTypes write(uint8_t *data, size_t len) = 0;
virtual OTAResponseTypes end() = 0; virtual OTAResponseTypes end() = 0;
virtual void abort() = 0; virtual void abort() = 0;
virtual bool supports_compression() = 0;
}; };
} // namespace ota } // namespace ota

View file

@ -15,6 +15,7 @@ class ArduinoESP32OTABackend : public OTABackend {
OTAResponseTypes write(uint8_t *data, size_t len) override; OTAResponseTypes write(uint8_t *data, size_t len) override;
OTAResponseTypes end() override; OTAResponseTypes end() override;
void abort() override; void abort() override;
bool supports_compression() override { return false; }
}; };
} // namespace ota } // namespace ota

View file

@ -16,6 +16,7 @@ class ArduinoESP8266OTABackend : public OTABackend {
OTAResponseTypes write(uint8_t *data, size_t len) override; OTAResponseTypes write(uint8_t *data, size_t len) override;
OTAResponseTypes end() override; OTAResponseTypes end() override;
void abort() override; void abort() override;
bool supports_compression() override { return true; }
}; };
} // namespace ota } // namespace ota

View file

@ -17,6 +17,7 @@ class IDFOTABackend : public OTABackend {
OTAResponseTypes write(uint8_t *data, size_t len) override; OTAResponseTypes write(uint8_t *data, size_t len) override;
OTAResponseTypes end() override; OTAResponseTypes end() override;
void abort() override; void abort() override;
bool supports_compression() override { return false; }
private: private:
esp_ota_handle_t update_handle_{0}; esp_ota_handle_t update_handle_{0};

View file

@ -104,6 +104,8 @@ void OTAComponent::loop() {
} }
} }
static const uint8_t FEATURE_SUPPORTS_COMPRESSION = 0x01;
void OTAComponent::handle_() { void OTAComponent::handle_() {
OTAResponseTypes error_code = OTA_RESPONSE_ERROR_UNKNOWN; OTAResponseTypes error_code = OTA_RESPONSE_ERROR_UNKNOWN;
bool update_started = false; bool update_started = false;
@ -154,6 +156,8 @@ void OTAComponent::handle_() {
buf[1] = OTA_VERSION_1_0; buf[1] = OTA_VERSION_1_0;
this->writeall_(buf, 2); this->writeall_(buf, 2);
backend = make_ota_backend();
// Read features - 1 byte // Read features - 1 byte
if (!this->readall_(buf, 1)) { if (!this->readall_(buf, 1)) {
ESP_LOGW(TAG, "Reading features failed!"); ESP_LOGW(TAG, "Reading features failed!");
@ -164,6 +168,10 @@ void OTAComponent::handle_() {
// Acknowledge header - 1 byte // Acknowledge header - 1 byte
buf[0] = OTA_RESPONSE_HEADER_OK; buf[0] = OTA_RESPONSE_HEADER_OK;
if ((ota_features & FEATURE_SUPPORTS_COMPRESSION) != 0 && backend->supports_compression()) {
buf[0] = OTA_RESPONSE_SUPPORTS_COMPRESSION;
}
this->writeall_(buf, 1); this->writeall_(buf, 1);
#ifdef USE_OTA_PASSWORD #ifdef USE_OTA_PASSWORD
@ -241,7 +249,6 @@ void OTAComponent::handle_() {
} }
ESP_LOGV(TAG, "OTA size is %u bytes", ota_size); ESP_LOGV(TAG, "OTA size is %u bytes", ota_size);
backend = make_ota_backend();
error_code = backend->begin(ota_size); error_code = backend->begin(ota_size);
if (error_code != OTA_RESPONSE_OK) if (error_code != OTA_RESPONSE_OK)
goto error; goto error;

View file

@ -19,6 +19,7 @@ enum OTAResponseTypes {
OTA_RESPONSE_BIN_MD5_OK = 67, OTA_RESPONSE_BIN_MD5_OK = 67,
OTA_RESPONSE_RECEIVE_OK = 68, OTA_RESPONSE_RECEIVE_OK = 68,
OTA_RESPONSE_UPDATE_END_OK = 69, OTA_RESPONSE_UPDATE_END_OK = 69,
OTA_RESPONSE_SUPPORTS_COMPRESSION = 70,
OTA_RESPONSE_ERROR_MAGIC = 128, OTA_RESPONSE_ERROR_MAGIC = 128,
OTA_RESPONSE_ERROR_UPDATE_PREPARE = 129, OTA_RESPONSE_ERROR_UPDATE_PREPARE = 129,

View file

@ -4,6 +4,7 @@ import random
import socket import socket
import sys import sys
import time import time
import gzip
from esphome.core import EsphomeError from esphome.core import EsphomeError
from esphome.helpers import is_ip_address, resolve_ip_address from esphome.helpers import is_ip_address, resolve_ip_address
@ -17,6 +18,7 @@ RESPONSE_UPDATE_PREPARE_OK = 66
RESPONSE_BIN_MD5_OK = 67 RESPONSE_BIN_MD5_OK = 67
RESPONSE_RECEIVE_OK = 68 RESPONSE_RECEIVE_OK = 68
RESPONSE_UPDATE_END_OK = 69 RESPONSE_UPDATE_END_OK = 69
RESPONSE_SUPPORTS_COMPRESSION = 70
RESPONSE_ERROR_MAGIC = 128 RESPONSE_ERROR_MAGIC = 128
RESPONSE_ERROR_UPDATE_PREPARE = 129 RESPONSE_ERROR_UPDATE_PREPARE = 129
@ -34,6 +36,8 @@ OTA_VERSION_1_0 = 1
MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45] MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45]
FEATURE_SUPPORTS_COMPRESSION = 0x01
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -170,11 +174,9 @@ def send_check(sock, data, msg):
def perform_ota(sock, password, file_handle, filename): def perform_ota(sock, password, file_handle, filename):
file_md5 = hashlib.md5(file_handle.read()).hexdigest() file_contents = file_handle.read()
file_size = file_handle.tell() file_size = len(file_contents)
_LOGGER.info("Uploading %s (%s bytes)", filename, file_size) _LOGGER.info("Uploading %s (%s bytes)", filename, file_size)
file_handle.seek(0)
_LOGGER.debug("MD5 of binary is %s", file_md5)
# Enable nodelay, we need it for phase 1 # Enable nodelay, we need it for phase 1
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
@ -185,8 +187,16 @@ def perform_ota(sock, password, file_handle, filename):
raise OTAError(f"Unsupported OTA version {version}") raise OTAError(f"Unsupported OTA version {version}")
# Features # Features
send_check(sock, 0x00, "features") send_check(sock, FEATURE_SUPPORTS_COMPRESSION, "features")
receive_exactly(sock, 1, "features", RESPONSE_HEADER_OK) features = receive_exactly(
sock, 1, "features", [RESPONSE_HEADER_OK, RESPONSE_SUPPORTS_COMPRESSION]
)[0]
if features == RESPONSE_SUPPORTS_COMPRESSION:
upload_contents = gzip.compress(file_contents, compresslevel=9)
_LOGGER.info("Compressed to %s bytes", len(upload_contents))
else:
upload_contents = file_contents
(auth,) = receive_exactly( (auth,) = receive_exactly(
sock, 1, "auth", [RESPONSE_REQUEST_AUTH, RESPONSE_AUTH_OK] sock, 1, "auth", [RESPONSE_REQUEST_AUTH, RESPONSE_AUTH_OK]
@ -213,16 +223,20 @@ def perform_ota(sock, password, file_handle, filename):
send_check(sock, result, "auth result") send_check(sock, result, "auth result")
receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK) receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK)
file_size_encoded = [ upload_size = len(upload_contents)
(file_size >> 24) & 0xFF, upload_size_encoded = [
(file_size >> 16) & 0xFF, (upload_size >> 24) & 0xFF,
(file_size >> 8) & 0xFF, (upload_size >> 16) & 0xFF,
(file_size >> 0) & 0xFF, (upload_size >> 8) & 0xFF,
(upload_size >> 0) & 0xFF,
] ]
send_check(sock, file_size_encoded, "binary size") send_check(sock, upload_size_encoded, "binary size")
receive_exactly(sock, 1, "binary size", RESPONSE_UPDATE_PREPARE_OK) receive_exactly(sock, 1, "binary size", RESPONSE_UPDATE_PREPARE_OK)
send_check(sock, file_md5, "file checksum") upload_md5 = hashlib.md5(upload_contents).hexdigest()
_LOGGER.debug("MD5 of upload is %s", upload_md5)
send_check(sock, upload_md5, "file checksum")
receive_exactly(sock, 1, "file checksum", RESPONSE_BIN_MD5_OK) receive_exactly(sock, 1, "file checksum", RESPONSE_BIN_MD5_OK)
# Disable nodelay for transfer # Disable nodelay for transfer
@ -236,7 +250,7 @@ def perform_ota(sock, password, file_handle, filename):
offset = 0 offset = 0
progress = ProgressBar() progress = ProgressBar()
while True: while True:
chunk = file_handle.read(1024) chunk = upload_contents[offset : offset + 1024]
if not chunk: if not chunk:
break break
offset += len(chunk) offset += len(chunk)
@ -247,7 +261,7 @@ def perform_ota(sock, password, file_handle, filename):
sys.stderr.write("\n") sys.stderr.write("\n")
raise OTAError(f"Error sending data: {err}") from err raise OTAError(f"Error sending data: {err}") from err
progress.update(offset / float(file_size)) progress.update(offset / upload_size)
progress.done() progress.done()
# Enable nodelay for last checks # Enable nodelay for last checks