From 644fd263a3e329a39599a28443a16199f48f5ed0 Mon Sep 17 00:00:00 2001 From: Tomasz Duda Date: Fri, 19 Jul 2024 09:28:46 +0200 Subject: [PATCH] add mcuboot support --- esphome/__main__.py | 83 +++++++++++++++++-- esphome/const.py | 2 + esphome/zephyr_tools.py | 173 ++++++++++++++++++++++++++++++++++++++++ requirements.txt | 4 + 4 files changed, 254 insertions(+), 8 deletions(-) create mode 100644 esphome/zephyr_tools.py diff --git a/esphome/__main__.py b/esphome/__main__.py index b13f96daf7..20fbf2dbb4 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -6,6 +6,7 @@ import os import re import sys import time +import asyncio from datetime import datetime import argcomplete @@ -36,6 +37,7 @@ from esphome.const import ( PLATFORM_RP2040, PLATFORM_RTL87XX, SECRETS_FILES, + PLATFORM_NRF52, ) from esphome.core import CORE, EsphomeError, coroutine from esphome.helpers import indent, is_ip_address @@ -47,6 +49,13 @@ from esphome.util import ( get_serial_ports, ) from esphome.log import color, setup_log, Fore +from .zephyr_tools import ( + logger_scan, + logger_connect, + smpmgr_scan, + smpmgr_upload, + is_mac_address, +) _LOGGER = logging.getLogger(__name__) @@ -86,19 +95,59 @@ def choose_prompt(options, purpose: str = None): def choose_upload_log_host( default, check_default, show_ota, show_mqtt, show_api, purpose: str = None ): + try: + mcuboot = CORE.config["nrf52"]["bootloader"] == "mcuboot" + except KeyError: + mcuboot = False + try: + ble_logger = CORE.config["zephyr_ble_nus"]["log"] + except KeyError: + ble_logger = False + ota = "ota" in CORE.config options = [] + prefix = "" + if mcuboot and show_ota and ota: + prefix = "mcumgr " for port in get_serial_ports(): - options.append((f"{port.path} ({port.description})", port.path)) + options.append( + (f"{prefix}{port.path} ({port.description})", f"{prefix}{port.path}") + ) if default == "SERIAL": return choose_prompt(options, purpose=purpose) - if (show_ota and "ota" in CORE.config) or (show_api and "api" in CORE.config): - options.append((f"Over The Air ({CORE.address})", CORE.address)) - if default == "OTA": - return CORE.address + if default == "PYOCD": + if not mcuboot: + raise EsphomeError("PYOCD for adafruit is not implemented") + options = [("pyocd", "PYOCD")] + return choose_prompt(options, purpose=purpose) + if not mcuboot: + if (show_ota and ota) or (show_api and "api" in CORE.config): + options.append((f"Over The Air ({CORE.address})", CORE.address)) + if default == "OTA": + return CORE.address + elif show_ota and ota: + if default: + options.append((f"OTA over Bluetooth LE ({default})", f"mcumgr {default}")) + return choose_prompt(options, purpose=purpose) + ble_devices = asyncio.run(smpmgr_scan(CORE.config["esphome"]["name"])) + if len(ble_devices) == 0: + _LOGGER.warning("No OTA over Bluetooth LE service found!") + for device in ble_devices: + options.append( + ( + f"OTA over Bluetooth LE({device.address}) {device.name}", + f"mcumgr {device.address}", + ) + ) if show_mqtt and CONF_MQTT in CORE.config: options.append((f"MQTT ({CORE.config['mqtt'][CONF_BROKER]})", "MQTT")) if default == "OTA": return "MQTT" + if "logging" == purpose and ble_logger and default is None: + ble_device = asyncio.run(logger_scan(CORE.config["esphome"]["name"])) + if ble_device: + options.append((f"Bluetooth LE logger ({ble_device})", ble_device.address)) + else: + _LOGGER.warning("No logger over Bluetooth LE service found!") if default is not None: return default if check_default is not None and check_default in [opt[1] for opt in options]: @@ -111,6 +160,8 @@ def get_port_type(port): return "SERIAL" if port == "MQTT": return "MQTT" + if is_mac_address(port): + return "BLE" return "NETWORK" @@ -289,10 +340,11 @@ def upload_using_esptool(config, port, file): return run_esptool(115200) -def upload_using_platformio(config, port): +def upload_using_platformio(config, port, upload_args=None): from esphome import platformio_api - upload_args = ["-t", "upload", "-t", "nobuild"] + if upload_args is None: + upload_args = ["-t", "upload", "-t", "nobuild"] if port is not None: upload_args += ["--upload-port", port] return platformio_api.run_platformio_cli_run(config, CORE.verbose, *upload_args) @@ -329,7 +381,19 @@ def upload_program(config, args, host): if CORE.target_platform in (PLATFORM_BK72XX, PLATFORM_RTL87XX): return upload_using_platformio(config, host) - return 1 # Unknown target platform + if CORE.target_platform in (PLATFORM_NRF52): + return upload_using_platformio(config, host, ["-t", "upload"]) + + raise EsphomeError(f"Unknown target platform: {CORE.target_platform}") + + if host == "PYOCD": + print(CORE) + return upload_using_platformio(config, host, ["-t", "flash_pyocd"]) + if host.startswith("mcumgr"): + firmware = os.path.abspath( + CORE.relative_pioenvs_path(CORE.name, "zephyr", "app_update.bin") + ) + return asyncio.run(smpmgr_upload(config, host.split(" ")[1], firmware)) ota_conf = {} for ota_item in config.get(CONF_OTA, []): @@ -389,6 +453,9 @@ def show_logs(config, args, port): config, args.topic, args.username, args.password, args.client_id ) + if get_port_type(port) == "BLE": + return asyncio.run(logger_connect(port)) + raise EsphomeError("No remote or local logging method configured (api/mqtt/logger)") diff --git a/esphome/const.py b/esphome/const.py index faf6ce19fa..7237e697fd 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -14,6 +14,7 @@ PLATFORM_HOST = "host" PLATFORM_BK72XX = "bk72xx" PLATFORM_RTL87XX = "rtl87xx" PLATFORM_LIBRETINY_OLDSTYLE = "libretiny" +PLATFORM_NRF52 = "nrf52" TARGET_PLATFORMS = [ PLATFORM_ESP32, @@ -23,6 +24,7 @@ TARGET_PLATFORMS = [ PLATFORM_BK72XX, PLATFORM_RTL87XX, PLATFORM_LIBRETINY_OLDSTYLE, + PLATFORM_NRF52, ] SOURCE_FILE_EXTENSIONS = {".cpp", ".hpp", ".h", ".c", ".tcc", ".ino"} diff --git a/esphome/zephyr_tools.py b/esphome/zephyr_tools.py new file mode 100644 index 0000000000..d32418e403 --- /dev/null +++ b/esphome/zephyr_tools.py @@ -0,0 +1,173 @@ +import time +import asyncio +import logging +import re +from typing import Final +from rich.pretty import pprint +from bleak import BleakScanner, BleakClient +from bleak.exc import BleakDeviceNotFoundError, BleakDBusError +from smpclient.transport.ble import SMPBLETransport +from smpclient.transport import SMPTransportDisconnected +from smpclient.transport.serial import SMPSerialTransport +from smpclient import SMPClient +from smpclient.mcuboot import IMAGE_TLV, ImageInfo, TLVNotFound, MCUBootImageError +from smpclient.requests.image_management import ImageStatesRead, ImageStatesWrite +from smpclient.requests.os_management import ResetWrite +from smpclient.generics import error, success +from smp.exceptions import SMPBadStartDelimiter +from esphome.espota2 import ProgressBar + +SMP_SERVICE_UUID = "8D53DC1D-1DB7-4CD3-868B-8A527460AA84" +NUS_SERVICE_UUID = "6E400001-B5A3-F393-E0A9-E50E24DCCA9E" +NUS_TX_CHAR_UUID = "6E400003-B5A3-F393-E0A9-E50E24DCCA9E" +MAC_ADDRESS_PATTERN: Final = re.compile( + r"([0-9A-F]{2}[:]){5}[0-9A-F]{2}$", flags=re.IGNORECASE +) + +_LOGGER = logging.getLogger(__name__) + + +def is_mac_address(value): + return MAC_ADDRESS_PATTERN.match(value) + + +async def logger_scan(name): + _LOGGER.info("Scanning bluetooth for %s...", name) + device = await BleakScanner.find_device_by_name(name) + return device + + +async def logger_connect(host): + disconnected_event = asyncio.Event() + + def handle_disconnect(client): + disconnected_event.set() + + def handle_rx(_, data: bytearray): + print(data.decode("utf-8"), end="") + + _LOGGER.info("Connecting %s...", host) + async with BleakClient(host, disconnected_callback=handle_disconnect) as client: + _LOGGER.info("Connected %s...", host) + try: + await client.start_notify(NUS_TX_CHAR_UUID, handle_rx) + except BleakDBusError as e: + _LOGGER.error("Bluetooth LE logger: %s", e) + disconnected_event.set() + await disconnected_event.wait() + + +async def smpmgr_scan(name): + _LOGGER.info("Scanning bluetooth for %s...", name) + devices = [] + for device in await BleakScanner.discover(service_uuids=[SMP_SERVICE_UUID]): + if device.name == name: + devices += [device] + return devices + + +def get_image_tlv_sha256(file): + _LOGGER.info("Checking image: %s", str(file)) + try: + image_info = ImageInfo.load_file(str(file)) + pprint(image_info.header) + _LOGGER.debug(str(image_info)) + except MCUBootImageError as e: + _LOGGER.error("Inspection of FW image failed: %s", e) + return None + + try: + image_tlv_sha256 = image_info.get_tlv(IMAGE_TLV.SHA256) + _LOGGER.debug("IMAGE_TLV_SHA256: %s", image_tlv_sha256) + except TLVNotFound: + _LOGGER.error("Could not find IMAGE_TLV_SHA256 in image.") + return None + return image_tlv_sha256.value + + +async def smpmgr_upload(config, host, firmware): + try: + return await smpmgr_upload_(config, host, firmware) + except SMPTransportDisconnected: + _LOGGER.error("%s was disconnected.", host) + return 1 + + +async def smpmgr_upload_(config, host, firmware): + image_tlv_sha256 = get_image_tlv_sha256(firmware) + if image_tlv_sha256 is None: + return 1 + + if is_mac_address(host): + smp_client = SMPClient(SMPBLETransport(), host) + else: + smp_client = SMPClient(SMPSerialTransport(), host) + + _LOGGER.info("Connecting %s...", host) + try: + await smp_client.connect() + except BleakDeviceNotFoundError: + _LOGGER.error("Device %s not found", host) + return 1 + + _LOGGER.info("Connected %s...", host) + + try: + image_state = await smp_client.request(ImageStatesRead(), 2.5) + except SMPBadStartDelimiter as e: + _LOGGER.error("mcumgr is not supported by device (%s)", e) + return 1 + + already_uploaded = False + + if error(image_state): + _LOGGER.error(image_state) + return 1 + if success(image_state): + if len(image_state.images) == 0: + _LOGGER.warning("No images on device!") + for image in image_state.images: + pprint(image) + if image.active and not image.confirmed: + _LOGGER.error("No free slot") + return 1 + if image.hash == image_tlv_sha256: + if already_uploaded: + _LOGGER.error("Both slots have the same image") + return 1 + if image.confirmed: + _LOGGER.error("Image already confirmted") + return 1 + _LOGGER.warning("The same image already uploaded") + already_uploaded = True + + if not already_uploaded: + with open(firmware, "rb") as file: + image = file.read() + file.close() + upload_size = len(image) + progress = ProgressBar() + progress.update(0) + try: + async for offset in smp_client.upload(image): + progress.update(offset / upload_size) + finally: + progress.done() + + _LOGGER.info("Mark image for testing") + r = await smp_client.request(ImageStatesWrite(hash=image_tlv_sha256), 1.0) + + if error(r): + _LOGGER.error(r) + return 1 + + # give a chance to execute completion callback + time.sleep(1) + _LOGGER.info("Reset") + r = await smp_client.request(ResetWrite(), 1.0) + + if error(r): + _LOGGER.error(r) + return 1 + + return 0 diff --git a/requirements.txt b/requirements.txt index 0cbe5e7265..b860fc3104 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,3 +27,7 @@ pyparsing >= 3.0 # For autocompletion argcomplete>=2.0.0 + +# for mcumgr +rich==13.7.0 +smpclient==3.2.0