Import BT UUID validations

This commit is contained in:
Rapsssito 2024-10-19 10:56:57 +02:00
parent d20011d0b2
commit fcb20b17f7
3 changed files with 47 additions and 61 deletions

View file

@ -1,3 +1,5 @@
import re
from esphome import automation from esphome import automation
import esphome.codegen as cg import esphome.codegen as cg
from esphome.components.esp32 import add_idf_sdkconfig_option, const, get_esp32_variant from esphome.components.esp32 import add_idf_sdkconfig_option, const, get_esp32_variant
@ -61,6 +63,43 @@ CONFIG_SCHEMA = cv.Schema(
).extend(cv.COMPONENT_SCHEMA) ).extend(cv.COMPONENT_SCHEMA)
bt_uuid16_format = "XXXX"
bt_uuid32_format = "XXXXXXXX"
bt_uuid128_format = "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"
def bt_uuid(value):
in_value = cv.string_strict(value)
value = in_value.upper()
if len(value) == len(bt_uuid16_format):
pattern = re.compile("^[A-F|0-9]{4,}$")
if not pattern.match(value):
raise cv.Invalid(
f"Invalid hexadecimal value for 16 bit UUID format: '{in_value}'"
)
return value
if len(value) == len(bt_uuid32_format):
pattern = re.compile("^[A-F|0-9]{8,}$")
if not pattern.match(value):
raise cv.Invalid(
f"Invalid hexadecimal value for 32 bit UUID format: '{in_value}'"
)
return value
if len(value) == len(bt_uuid128_format):
pattern = re.compile(
"^[A-F|0-9]{8,}-[A-F|0-9]{4,}-[A-F|0-9]{4,}-[A-F|0-9]{4,}-[A-F|0-9]{12,}$"
)
if not pattern.match(value):
raise cv.Invalid(
f"Invalid hexadecimal value for 128 UUID format: '{in_value}'"
)
return value
raise cv.Invalid(
f"Service UUID must be in 16 bit '{bt_uuid16_format}', 32 bit '{bt_uuid32_format}', or 128 bit '{bt_uuid128_format}' format"
)
def validate_variant(_): def validate_variant(_):
variant = get_esp32_variant() variant = get_esp32_variant()
if variant in NO_BLUETOOTH_VARIANTS: if variant in NO_BLUETOOTH_VARIANTS:

View file

@ -2,6 +2,7 @@ from esphome import automation
import esphome.codegen as cg import esphome.codegen as cg
from esphome.components import esp32_ble from esphome.components import esp32_ble
from esphome.components.esp32 import add_idf_sdkconfig_option from esphome.components.esp32 import add_idf_sdkconfig_option
from esphome.components.esp32_ble import bt_uuid, bt_uuid16_format, bt_uuid32_format, bt_uuid128_format
import esphome.config_validation as cv import esphome.config_validation as cv
from esphome.const import ( from esphome.const import (
CONF_ID, CONF_ID,
@ -70,12 +71,6 @@ PROPERTY_MAP = {
} }
def validate_uuid(value):
if len(value) != 36:
raise cv.Invalid("UUID must be exactly 36 characters long")
return value
def validate_on_write(char_config): def validate_on_write(char_config):
if CONF_ON_WRITE in char_config: if CONF_ON_WRITE in char_config:
if not char_config[CONF_WRITE] and not char_config[CONF_WRITE_NO_RESPONSE]: if not char_config[CONF_WRITE] and not char_config[CONF_WRITE_NO_RESPONSE]:
@ -113,8 +108,6 @@ def validate_notify_action(action_char_id):
return action_char_id return action_char_id
UUID_SCHEMA = cv.Any(cv.All(cv.string, validate_uuid), cv.uint32_t)
DESCRIPTOR_VALUE_SCHEMA = cv.Any( DESCRIPTOR_VALUE_SCHEMA = cv.Any(
cv.boolean, cv.boolean,
cv.float_, cv.float_,
@ -140,7 +133,7 @@ CHARACTERISTIC_VALUE_SCHEMA = cv.Any(
DESCRIPTOR_SCHEMA = cv.Schema( DESCRIPTOR_SCHEMA = cv.Schema(
{ {
cv.GenerateID(): cv.declare_id(BLEDescriptor), cv.GenerateID(): cv.declare_id(BLEDescriptor),
cv.Required(CONF_UUID): UUID_SCHEMA, cv.Required(CONF_UUID): bt_uuid,
cv.Required(CONF_VALUE): DESCRIPTOR_VALUE_SCHEMA, cv.Required(CONF_VALUE): DESCRIPTOR_VALUE_SCHEMA,
} }
) )
@ -148,7 +141,7 @@ DESCRIPTOR_SCHEMA = cv.Schema(
SERVICE_CHARACTERISTIC_SCHEMA = cv.Schema( SERVICE_CHARACTERISTIC_SCHEMA = cv.Schema(
{ {
cv.GenerateID(): cv.declare_id(BLECharacteristic), cv.GenerateID(): cv.declare_id(BLECharacteristic),
cv.Required(CONF_UUID): UUID_SCHEMA, cv.Required(CONF_UUID): bt_uuid,
cv.Optional(CONF_WRITE_NO_RESPONSE, default=False): cv.boolean, cv.Optional(CONF_WRITE_NO_RESPONSE, default=False): cv.boolean,
cv.Optional(CONF_VALUE): CHARACTERISTIC_VALUE_SCHEMA, cv.Optional(CONF_VALUE): CHARACTERISTIC_VALUE_SCHEMA,
cv.GenerateID(CONF_VALUE_ACTION_ID_): cv.declare_id( cv.GenerateID(CONF_VALUE_ACTION_ID_): cv.declare_id(
@ -165,7 +158,7 @@ SERVICE_CHARACTERISTIC_SCHEMA = cv.Schema(
SERVICE_SCHEMA = cv.Schema( SERVICE_SCHEMA = cv.Schema(
{ {
cv.GenerateID(): cv.declare_id(BLEService), cv.GenerateID(): cv.declare_id(BLEService),
cv.Required(CONF_UUID): UUID_SCHEMA, cv.Required(CONF_UUID): bt_uuid,
cv.Optional(CONF_ADVERTISE, default=False): cv.boolean, cv.Optional(CONF_ADVERTISE, default=False): cv.boolean,
cv.Optional(CONF_CHARACTERISTICS, default=[]): cv.ensure_list( cv.Optional(CONF_CHARACTERISTICS, default=[]): cv.ensure_list(
SERVICE_CHARACTERISTIC_SCHEMA SERVICE_CHARACTERISTIC_SCHEMA
@ -192,14 +185,6 @@ def parse_properties(char_conf):
) )
def parse_uuid(uuid):
# If the UUID is a string, use from_raw
if isinstance(uuid, str):
return ESPBTUUID_ns.from_raw(uuid)
# Otherwise, use from_uint32
return ESPBTUUID_ns.from_uint32(uuid)
def parse_descriptor_value(value): def parse_descriptor_value(value):
# Compute the maximum length of the descriptor value # Compute the maximum length of the descriptor value
# Also parse the value for byte arrays # Also parse the value for byte arrays
@ -288,7 +273,7 @@ async def to_code(config):
service_var = cg.Pvariable( service_var = cg.Pvariable(
service_config[CONF_ID], service_config[CONF_ID],
var.create_service( var.create_service(
parse_uuid(service_config[CONF_UUID]), ESPBTUUID_ns.from_raw(service_config[CONF_UUID]),
service_config[CONF_ADVERTISE], service_config[CONF_ADVERTISE],
num_handles, num_handles,
), ),
@ -297,7 +282,7 @@ async def to_code(config):
char_var = cg.Pvariable( char_var = cg.Pvariable(
char_conf[CONF_ID], char_conf[CONF_ID],
service_var.create_characteristic( service_var.create_characteristic(
parse_uuid(char_conf[CONF_UUID]), ESPBTUUID_ns.from_raw(char_conf[CONF_UUID]),
parse_properties(char_conf), parse_properties(char_conf),
), ),
) )
@ -326,7 +311,7 @@ async def to_code(config):
) )
desc_var = cg.new_Pvariable( desc_var = cg.new_Pvariable(
descriptor_conf[CONF_ID], descriptor_conf[CONF_ID],
parse_uuid(descriptor_conf[CONF_UUID]), ESPBTUUID_ns.from_raw(descriptor_conf[CONF_UUID]),
max_length, max_length,
) )
if CONF_VALUE in descriptor_conf: if CONF_VALUE in descriptor_conf:

View file

@ -1,9 +1,8 @@
import re
from esphome import automation from esphome import automation
import esphome.codegen as cg import esphome.codegen as cg
from esphome.components import esp32_ble from esphome.components import esp32_ble
from esphome.components.esp32 import add_idf_sdkconfig_option from esphome.components.esp32 import add_idf_sdkconfig_option
from esphome.components.esp32_ble import bt_uuid, bt_uuid16_format, bt_uuid32_format, bt_uuid128_format
import esphome.config_validation as cv import esphome.config_validation as cv
from esphome.const import ( from esphome.const import (
CONF_ACTIVE, CONF_ACTIVE,
@ -86,43 +85,6 @@ def validate_scan_parameters(config):
return config return config
bt_uuid16_format = "XXXX"
bt_uuid32_format = "XXXXXXXX"
bt_uuid128_format = "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"
def bt_uuid(value):
in_value = cv.string_strict(value)
value = in_value.upper()
if len(value) == len(bt_uuid16_format):
pattern = re.compile("^[A-F|0-9]{4,}$")
if not pattern.match(value):
raise cv.Invalid(
f"Invalid hexadecimal value for 16 bit UUID format: '{in_value}'"
)
return value
if len(value) == len(bt_uuid32_format):
pattern = re.compile("^[A-F|0-9]{8,}$")
if not pattern.match(value):
raise cv.Invalid(
f"Invalid hexadecimal value for 32 bit UUID format: '{in_value}'"
)
return value
if len(value) == len(bt_uuid128_format):
pattern = re.compile(
"^[A-F|0-9]{8,}-[A-F|0-9]{4,}-[A-F|0-9]{4,}-[A-F|0-9]{4,}-[A-F|0-9]{12,}$"
)
if not pattern.match(value):
raise cv.Invalid(
f"Invalid hexadecimal value for 128 UUID format: '{in_value}'"
)
return value
raise cv.Invalid(
f"Service UUID must be in 16 bit '{bt_uuid16_format}', 32 bit '{bt_uuid32_format}', or 128 bit '{bt_uuid128_format}' format"
)
def as_hex(value): def as_hex(value):
return cg.RawExpression(f"0x{value}ULL") return cg.RawExpression(f"0x{value}ULL")