# Mirror of https://github.com/esphome/esphome.git (synced 2024-12-18 11:34:54 +01:00).
# Executable Python file: 977 lines, 25 KiB.
#!/usr/bin/env python3
|
|
"""Python 3 script to automatically generate C++ classes for ESPHome's native API.
|
|
|
|
It's pretty crappy spaghetti code, but it works.
|
|
|
|
You need to install protobuf-compiler:
|
|
Running `protoc --version` should return
|
|
libprotoc 3.6.1
|
|
|
|
Then run this script with python3 and the files
|
|
|
|
esphome/components/api/api_pb2_service.h
|
|
esphome/components/api/api_pb2_service.cpp
|
|
esphome/components/api/api_pb2.h
|
|
esphome/components/api/api_pb2.cpp
|
|
|
|
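will be generated. They still need to be formatted; this script does that automatically with clang-format when the clang_format Python package is installed.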
|
|
"""
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
from abc import ABC, abstractmethod
|
|
from pathlib import Path
|
|
from subprocess import call
|
|
from textwrap import dedent
|
|
|
|
# The api_options_pb2 options module can be (re)generated with
|
|
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
|
|
import aioesphomeapi.api_options_pb2 as pb
|
|
import google.protobuf.descriptor_pb2 as descriptor
|
|
|
|
# Banner written at the top of every generated C++ file. It points readers at
# this generator so the output is regenerated rather than edited by hand; the
# generator lives in the ``script/api_protobuf`` directory.
FILE_HEADER = """// This file was automatically generated with a tool.
// See script/api_protobuf/api_protobuf.py
"""
|
|
|
|
|
|
def indent_list(text, padding=" "):
    """Split *text* into lines and prefix each line with *padding*.

    Empty lines and preprocessor ``#ifdef`` / ``#endif`` lines are returned
    unchanged, so conditional-compilation markers stay in column 0.

    Returns the resulting list of lines (without trailing newlines).
    """
    return [
        line if not line or line.startswith(("#ifdef", "#endif")) else padding + line
        for line in text.splitlines()
    ]
|
|
|
|
|
|
def indent(text, padding=" "):
    """Return *text* with each line prefixed by *padding*, joined by newlines.

    Line selection follows ``indent_list``: empty lines and ``#ifdef`` /
    ``#endif`` lines are not indented.
    """
    indented_lines = indent_list(text, padding)
    return "\n".join(indented_lines)
|
|
|
|
|
|
def camel_to_snake(name):
    """Convert a CamelCase identifier to snake_case.

    Example: ``"HelloRequest"`` -> ``"hello_request"``.
    Based on https://stackoverflow.com/a/1176023
    """
    # First split before capitalised words, then between a lowercase
    # letter/digit and a following uppercase letter; finally lowercase it all.
    for pattern in ("(.)([A-Z][a-z]+)", "([a-z0-9])([A-Z])"):
        name = re.sub(pattern, r"\1_\2", name)
    return name.lower()
|
|
|
|
|
|
class TypeInfo(ABC):
    """Describe how one protobuf field is represented in generated C++.

    Concrete subclasses set ``cpp_type``, ``default_value``, the ``decode_*``
    expression for each wire type they accept (``None`` = not accepted) and
    ``encode_func``, and implement ``dump``. The properties below assemble
    those pieces into the C++ snippets that make up a message class.
    """

    def __init__(self, field):
        # ``field`` is a field descriptor (FieldDescriptorProto) from the
        # compiled api.proto; name/number/label/type are read from it.
        self._field = field

    @property
    def default_value(self):
        """Initializer placed inside the braces of the member declaration."""
        return ""

    @property
    def name(self):
        """Field name as written in the .proto file."""
        return self._field.name

    @property
    def arg_name(self):
        return self.name

    @property
    def field_name(self):
        """Name of the generated C++ member."""
        return self.name

    @property
    def number(self):
        """Protobuf field number; used as the ``case`` label when decoding."""
        return self._field.number

    @property
    def repeated(self):
        # 3 == FieldDescriptorProto.LABEL_REPEATED
        return self._field.label == 3

    @property
    def cpp_type(self):
        raise NotImplementedError

    @property
    def reference_type(self):
        return f"{self.cpp_type} "

    @property
    def const_reference_type(self):
        return f"{self.cpp_type} "

    @property
    def public_content(self) -> "list[str]":
        """Declarations added to the public section of the message class."""
        return [self.class_member]

    @property
    def protected_content(self) -> "list[str]":
        """Declarations added to the protected section (none by default)."""
        return []

    @property
    def class_member(self) -> str:
        """Member declaration with a brace initializer, e.g. ``uint32_t key{0};``."""
        return f"{self.cpp_type} {self.field_name}{{{self.default_value}}};"

    def _decode_case(self, content) -> "str | None":
        """Return a switch ``case`` assigning *content* to the member.

        Returns None when *content* is None, i.e. the type cannot be decoded
        from that wire type.
        """
        if content is None:
            return None
        return (
            f"case {self.number}: {{\n"
            f"  this->{self.field_name} = {content};\n"
            "  return true;\n"
            "}"
        )

    # C++ expressions converting the decoded wire value into the member value,
    # one per wire type. None means this type is not decoded from that wire
    # type. Subclasses override them as class attributes or properties.
    decode_varint = None
    decode_length = None
    decode_32bit = None
    decode_64bit = None

    @property
    def decode_varint_content(self) -> "str | None":
        """``case`` for the generated decode_varint(), or None."""
        return self._decode_case(self.decode_varint)

    @property
    def decode_length_content(self) -> "str | None":
        """``case`` for the generated decode_length(), or None."""
        return self._decode_case(self.decode_length)

    @property
    def decode_32bit_content(self) -> "str | None":
        """``case`` for the generated decode_32bit(), or None."""
        return self._decode_case(self.decode_32bit)

    @property
    def decode_64bit_content(self) -> "str | None":
        """``case`` for the generated decode_64bit(), or None."""
        return self._decode_case(self.decode_64bit)

    # Name of the ProtoWriteBuffer method called by encode_content.
    encode_func = None

    @property
    def encode_content(self):
        """Statement writing this field into ``buffer`` inside encode()."""
        return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"

    @property
    def dump_content(self):
        """Statements appending ``" <name>: <value>\\n"`` to ``out`` in dump_to()."""
        o = f'out.append(" {self.name}: ");\n'
        o += self.dump(f"this->{self.field_name}") + "\n"
        o += 'out.append("\\n");\n'
        return o

    @abstractmethod
    def dump(self, name: str):
        """Return C++ statements appending the value of expression *name* to ``out``."""
|
|
|
|
|
|
# Registry: protobuf field type number -> TypeInfo subclass handling it.
TYPE_INFO = {}


def register_type(name):
    """Class decorator factory.

    The returned decorator stores the decorated class in ``TYPE_INFO`` under
    key *name* (a protobuf field type number) and returns the class unchanged.
    """

    def decorator(cls):
        TYPE_INFO[name] = cls
        return cls

    return decorator
|
|
|
|
|
|
@register_type(1)
class DoubleType(TypeInfo):
    """``double`` field: decoded from a 64-bit wire value, dumped with ``%g``."""

    encode_func = "encode_double"
    decode_64bit = "value.as_double()"
    default_value = "0.0"
    cpp_type = "double"

    def dump(self, name):
        return f'sprintf(buffer, "%g", {name});\n' + "out.append(buffer);"
|
|
|
|
|
|
@register_type(2)
class FloatType(TypeInfo):
    """``float`` field: decoded from a 32-bit wire value, dumped with ``%g``."""

    encode_func = "encode_float"
    decode_32bit = "value.as_float()"
    default_value = "0.0f"
    cpp_type = "float"

    def dump(self, name):
        return f'sprintf(buffer, "%g", {name});\n' + "out.append(buffer);"
|
|
|
|
|
|
@register_type(3)
class Int64Type(TypeInfo):
    """``int64`` field: varint on the wire, ``int64_t`` in C++."""

    cpp_type = "int64_t"
    default_value = "0"
    decode_varint = "value.as_int64()"
    encode_func = "encode_int64"

    def dump(self, name):
        # Use the <cinttypes> macro (already included by the generated .cpp):
        # "%lld" does not match int64_t on hosts where int64_t is ``long``.
        o = f'sprintf(buffer, "%" PRId64, {name});\n'
        o += "out.append(buffer);"
        return o
|
|
|
|
|
|
@register_type(4)
class UInt64Type(TypeInfo):
    """``uint64`` field: varint on the wire, ``uint64_t`` in C++."""

    cpp_type = "uint64_t"
    default_value = "0"
    decode_varint = "value.as_uint64()"
    encode_func = "encode_uint64"

    def dump(self, name):
        # Use the <cinttypes> macro (already included by the generated .cpp):
        # "%llu" does not match uint64_t on hosts where it is ``unsigned long``.
        o = f'sprintf(buffer, "%" PRIu64, {name});\n'
        o += "out.append(buffer);"
        return o
|
|
|
|
|
|
@register_type(5)
class Int32Type(TypeInfo):
    """``int32`` field: varint on the wire, ``int32_t`` in C++."""

    encode_func = "encode_int32"
    decode_varint = "value.as_int32()"
    default_value = "0"
    cpp_type = "int32_t"

    def dump(self, name):
        return f'sprintf(buffer, "%" PRId32, {name});\n' + "out.append(buffer);"
|
|
|
|
|
|
@register_type(6)
class Fixed64Type(TypeInfo):
    """``fixed64`` field: 64-bit wire value, ``uint64_t`` in C++."""

    cpp_type = "uint64_t"
    default_value = "0"
    decode_64bit = "value.as_fixed64()"
    encode_func = "encode_fixed64"

    def dump(self, name):
        # Use the <cinttypes> macro (already included by the generated .cpp):
        # "%llu" does not match uint64_t on hosts where it is ``unsigned long``.
        o = f'sprintf(buffer, "%" PRIu64, {name});\n'
        o += "out.append(buffer);"
        return o
|
|
|
|
|
|
@register_type(7)
class Fixed32Type(TypeInfo):
    """``fixed32`` field: 32-bit wire value, ``uint32_t`` in C++."""

    encode_func = "encode_fixed32"
    decode_32bit = "value.as_fixed32()"
    default_value = "0"
    cpp_type = "uint32_t"

    def dump(self, name):
        return f'sprintf(buffer, "%" PRIu32, {name});\n' + "out.append(buffer);"
|
|
|
|
|
|
@register_type(8)
class BoolType(TypeInfo):
    """``bool`` field: varint on the wire, dumped via ``YESNO``."""

    encode_func = "encode_bool"
    decode_varint = "value.as_bool()"
    default_value = "false"
    cpp_type = "bool"

    def dump(self, name):
        return f"out.append(YESNO({name}));"
|
|
|
|
|
|
@register_type(9)
class StringType(TypeInfo):
    """``string`` field stored as ``std::string``; dumped between single quotes."""

    encode_func = "encode_string"
    decode_length = "value.as_string()"
    const_reference_type = "const std::string &"
    reference_type = "std::string &"
    default_value = ""
    cpp_type = "std::string"

    def dump(self, name):
        quote = "'"
        return f'out.append("{quote}").append({name}).append("{quote}");'
|
|
|
|
|
|
@register_type(11)
class MessageType(TypeInfo):
    """Embedded message field; the C++ type is the generated message class."""

    default_value = ""

    @property
    def cpp_type(self):
        # Drop the first character of type_name (the leading "." of the
        # fully-qualified protobuf type name).
        return self._field.type_name[1:]

    @property
    def reference_type(self):
        return f"{self.cpp_type} &"

    @property
    def const_reference_type(self):
        return "const " + self.reference_type

    @property
    def decode_length(self):
        return f"value.as_message<{self.cpp_type}>()"

    @property
    def encode_func(self):
        return f"encode_message<{self.cpp_type}>"

    def dump(self, name):
        return f"{name}.dump_to(out);"
|
|
|
|
|
|
@register_type(12)
class BytesType(TypeInfo):
    """``bytes`` field stored as ``std::string``; dumped between single quotes."""

    encode_func = "encode_string"
    decode_length = "value.as_string()"
    const_reference_type = "const std::string &"
    reference_type = "std::string &"
    default_value = ""
    cpp_type = "std::string"

    def dump(self, name):
        quote = "'"
        return f'out.append("{quote}").append({name}).append("{quote}");'
|
|
|
|
|
|
@register_type(13)
class UInt32Type(TypeInfo):
    """``uint32`` field: varint on the wire, ``uint32_t`` in C++."""

    encode_func = "encode_uint32"
    decode_varint = "value.as_uint32()"
    default_value = "0"
    cpp_type = "uint32_t"

    def dump(self, name):
        return f'sprintf(buffer, "%" PRIu32, {name});\n' + "out.append(buffer);"
|
|
|
|
|
|
@register_type(14)
class EnumType(TypeInfo):
    """Enum field: the C++ type lives in the generated ``enums`` namespace."""

    default_value = ""

    @property
    def cpp_type(self):
        # Strip the first character (leading "." of the qualified type name).
        return "enums::" + self._field.type_name[1:]

    @property
    def decode_varint(self):
        return f"value.as_enum<{self.cpp_type}>()"

    @property
    def encode_func(self):
        return f"encode_enum<{self.cpp_type}>"

    def dump(self, name):
        return f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
|
|
|
|
|
|
@register_type(15)
class SFixed32Type(TypeInfo):
    """``sfixed32`` field: 32-bit wire value, ``int32_t`` in C++."""

    encode_func = "encode_sfixed32"
    decode_32bit = "value.as_sfixed32()"
    default_value = "0"
    cpp_type = "int32_t"

    def dump(self, name):
        return f'sprintf(buffer, "%" PRId32, {name});\n' + "out.append(buffer);"
|
|
|
|
|
|
@register_type(16)
class SFixed64Type(TypeInfo):
    """``sfixed64`` field: 64-bit wire value, ``int64_t`` in C++."""

    cpp_type = "int64_t"
    default_value = "0"
    decode_64bit = "value.as_sfixed64()"
    encode_func = "encode_sfixed64"

    def dump(self, name):
        # Use the <cinttypes> macro (already included by the generated .cpp):
        # "%lld" does not match int64_t on hosts where int64_t is ``long``.
        o = f'sprintf(buffer, "%" PRId64, {name});\n'
        o += "out.append(buffer);"
        return o
|
|
|
|
|
|
@register_type(17)
class SInt32Type(TypeInfo):
    """``sint32`` field: varint on the wire, ``int32_t`` in C++."""

    encode_func = "encode_sint32"
    decode_varint = "value.as_sint32()"
    default_value = "0"
    cpp_type = "int32_t"

    def dump(self, name):
        return f'sprintf(buffer, "%" PRId32, {name});\n' + "out.append(buffer);"
|
|
|
|
|
|
@register_type(18)
class SInt64Type(TypeInfo):
    """``sint64`` field: varint on the wire, ``int64_t`` in C++."""

    cpp_type = "int64_t"
    default_value = "0"
    decode_varint = "value.as_sint64()"
    encode_func = "encode_sint64"

    def dump(self, name):
        # Use the <cinttypes> macro (already included by the generated .cpp):
        # "%lld" does not match int64_t on hosts where int64_t is ``long``.
        o = f'sprintf(buffer, "%" PRId64, {name});\n'
        o += "out.append(buffer);"
        return o
|
|
|
|
|
|
class RepeatedTypeInfo(TypeInfo):
    """Type info for a ``repeated`` field.

    The member becomes ``std::vector<element type>``. Every decoded value is
    appended with ``push_back``, and encode/dump iterate over the vector. The
    per-element behaviour comes from the element's own TypeInfo.
    """

    def __init__(self, field):
        super().__init__(field)
        # Element type info, looked up by the protobuf field type.
        self._ti = TYPE_INFO[field.type](field)

    @property
    def cpp_type(self):
        return f"std::vector<{self._ti.cpp_type}>"

    @property
    def reference_type(self):
        return f"{self.cpp_type} &"

    @property
    def const_reference_type(self):
        return f"const {self.cpp_type} &"

    def _push_back_case(self, content) -> "str | None":
        """Return a switch ``case`` appending *content* to the vector.

        Returns None when *content* is None, i.e. the element type is not
        decoded from that wire type.
        """
        if content is None:
            return None
        return (
            f"case {self.number}: {{\n"
            f"  this->{self.field_name}.push_back({content});\n"
            "  return true;\n"
            "}"
        )

    @property
    def decode_varint_content(self) -> "str | None":
        return self._push_back_case(self._ti.decode_varint)

    @property
    def decode_length_content(self) -> "str | None":
        return self._push_back_case(self._ti.decode_length)

    @property
    def decode_32bit_content(self) -> "str | None":
        return self._push_back_case(self._ti.decode_32bit)

    @property
    def decode_64bit_content(self) -> "str | None":
        return self._push_back_case(self._ti.decode_64bit)

    @property
    def _ti_is_bool(self):
        # std::vector<bool> is specialized: iterating it yields proxy/bool
        # values rather than real element references, so binding ``auto &``
        # fails to compile. Bool elements are therefore taken by value.
        return isinstance(self._ti, BoolType)

    @property
    def encode_content(self):
        ref = "" if self._ti_is_bool else "&"
        o = f"for (auto {ref}it : this->{self.field_name}) {{\n"
        # NOTE(review): the trailing ``true`` presumably forces each element to
        # be written even when it equals the default value; confirm in proto.h.
        o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
        o += "}"
        return o

    @property
    def dump_content(self):
        ref = "" if self._ti_is_bool else "&"
        o = f"for (const auto {ref}it : this->{self.field_name}) {{\n"
        o += f' out.append(" {self.name}: ");\n'
        o += indent(self._ti.dump("it")) + "\n"
        o += ' out.append("\\n");\n'
        o += "}\n"
        return o

    def dump(self, _: str):
        # Never used: dump_content above dumps element by element via self._ti.
        pass
|
|
|
|
|
|
def build_enum_type(desc):
    """Generate C++ for one protobuf enum descriptor.

    Returns ``(header, source)``:

    * *header* is the ``enum <Name> : uint32_t`` declaration;
    * *source* is a ``proto_enum_to_string`` specialization, guarded by
      ``HAS_PROTO_MESSAGE_DUMP``. It maps each value to its name and anything
      else to ``"UNKNOWN"``.
    """
    name = desc.name
    values = list(desc.value)

    header_lines = [f"enum {name} : uint32_t {{"]
    header_lines.extend(f" {v.name} = {v.number}," for v in values)
    header_lines.append("};")

    source_lines = [
        "#ifdef HAS_PROTO_MESSAGE_DUMP",
        f"template<> const char *proto_enum_to_string<enums::{name}>(enums::{name} value) {{",
        " switch (value) {",
    ]
    for v in values:
        source_lines.append(f" case enums::{v.name}:")
        source_lines.append(f' return "{v.name}";')
    source_lines.extend([" default:", ' return "UNKNOWN";', " }", "}", "#endif"])

    return "\n".join(header_lines) + "\n", "\n".join(source_lines) + "\n"
|
|
|
|
|
|
def _build_decode_method(message_name, method, arg_type, cases):
    """Return ``(definition, declaration)`` for one generated ``decode_*`` override.

    *cases* are the per-field ``case`` snippets of the switch. A ``default``
    branch returning false (field not handled) is appended after them.
    """
    switch_body = "\n".join([*cases, "default:\n return false;"])
    definition = f"bool {message_name}::{method}(uint32_t field_id, {arg_type} value) {{\n"
    definition += " switch (field_id) {\n"
    definition += indent(switch_body, " ") + "\n"
    definition += " }\n"
    definition += "}\n"
    declaration = f"bool {method}(uint32_t field_id, {arg_type} value) override;"
    return definition, declaration


def build_message_type(desc):
    """Generate the C++ class for one protobuf message descriptor.

    Returns ``(header, source)``:

    * *header* is the class declaration: field members, the encode/dump_to
      declarations, and the decode_* overrides the fields need;
    * *source* holds the out-of-line definitions of those methods. dump_to is
      guarded by ``HAS_PROTO_MESSAGE_DUMP``.
    """
    public_content = []
    protected_content = []
    decode_varint = []
    decode_length = []
    decode_32bit = []
    decode_64bit = []
    encode = []
    dump = []

    for field in desc.field:
        # 3 == FieldDescriptorProto.LABEL_REPEATED
        if field.label == 3:
            ti = RepeatedTypeInfo(field)
        else:
            ti = TYPE_INFO[field.type](field)
        protected_content.extend(ti.protected_content)
        public_content.extend(ti.public_content)
        encode.append(ti.encode_content)

        if ti.decode_varint_content:
            decode_varint.append(ti.decode_varint_content)
        if ti.decode_length_content:
            decode_length.append(ti.decode_length_content)
        if ti.decode_32bit_content:
            decode_32bit.append(ti.decode_32bit_content)
        if ti.decode_64bit_content:
            decode_64bit.append(ti.decode_64bit_content)
        if ti.dump_content:
            dump.append(ti.dump_content)

    cpp = ""
    decoders = (
        ("decode_varint", "ProtoVarInt", decode_varint),
        ("decode_length", "ProtoLengthDelimited", decode_length),
        ("decode_32bit", "Proto32Bit", decode_32bit),
        ("decode_64bit", "Proto64Bit", decode_64bit),
    )
    for method, arg_type, cases in decoders:
        if not cases:
            continue
        definition, declaration = _build_decode_method(desc.name, method, arg_type, cases)
        cpp += definition
        # Each declaration goes to the front, so the final order is 64bit,
        # 32bit, length, varint, then the fields' own protected members.
        protected_content.insert(0, declaration)

    o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"
    if encode:
        if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120:
            o += f" {encode[0]} "
        else:
            o += "\n"
            o += indent("\n".join(encode)) + "\n"
    o += "}\n"
    cpp += o
    public_content.append("void encode(ProtoWriteBuffer buffer) const override;")

    o = f"void {desc.name}::dump_to(std::string &out) const {{"
    if dump:
        # Always emit the full form when there are fields. A one-line shortcut
        # would drop the ``buffer`` declaration (used by numeric dumps) and the
        # "Name {" ... "}" framing.
        o += "\n"
        o += " __attribute__((unused)) char buffer[64];\n"
        o += f' out.append("{desc.name} {{\\n");\n'
        o += indent("\n".join(dump)) + "\n"
        o += ' out.append("}");\n'
    else:
        o2 = f'out.append("{desc.name} {{}}");'
        if len(o) + len(o2) + 3 < 120:
            o += f" {o2} "
        else:
            o += "\n"
            o += f" {o2}\n"
    o += "}\n"
    cpp += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
    cpp += o
    cpp += "#endif\n"
    prot = "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
    prot += "void dump_to(std::string &out) const override;\n"
    prot += "#endif\n"
    public_content.append(prot)

    out = f"class {desc.name} : public ProtoMessage {{\n"
    out += " public:\n"
    out += indent("\n".join(public_content)) + "\n"
    out += "\n"
    out += " protected:\n"
    out += indent("\n".join(protected_content))
    if len(protected_content) > 0:
        out += "\n"
    out += "};\n"
    return out, cpp
|
|
|
|
|
|
# Values of the ``source`` message option, i.e. which side sends the message:
# BOTH/SERVER get a send_*() method, BOTH/CLIENT get an on_*() receive hook.
SOURCE_BOTH, SOURCE_SERVER, SOURCE_CLIENT = 0, 1, 2

# Filled in by build_service_message_type():
# message id -> body of the matching ``case`` in read_message().
RECEIVE_CASES = {}

# message name -> preprocessor symbol that guards the message's code.
ifdefs = {}
|
|
|
|
|
|
def get_opt(desc, opt, default=None):
    """Return the custom option extension *opt* set on ``desc.options``.

    Returns *default* when the option is not present.
    """
    options = desc.options
    return options.Extensions[opt] if options.HasExtension(opt) else default
|
|
|
|
|
|
def build_service_message_type(mt):
    """Generate APIServerConnectionBase members for message type *mt*.

    Messages without the ``id`` option are skipped and None is returned.
    Otherwise ``(header, source)`` snippets are returned:

    * messages sent by the server (source BOTH or SERVER) get a
      ``send_<snake>()`` declaration and definition;
    * messages sent by the client (source BOTH or CLIENT) get a virtual
      ``on_<snake>()`` hook. The ``read_message`` case that decodes and
      dispatches them is stored in ``RECEIVE_CASES`` under the message id.

    With an ``ifdef`` option, the generated code is wrapped in ``#ifdef``
    guards and the symbol is recorded in ``ifdefs`` under the message name.
    """
    message_id = get_opt(mt, pb.id)
    if message_id is None:
        return None

    snake = camel_to_snake(mt.name)
    origin = get_opt(mt, pb.source, 0)
    guard = get_opt(mt, pb.ifdef)
    log = get_opt(mt, pb.log, True)

    header_parts = []
    source_parts = []

    if guard is not None:
        ifdefs[str(mt.name)] = guard
        header_parts.append(f"#ifdef {guard}\n")
        source_parts.append(f"#ifdef {guard}\n")

    if origin in (SOURCE_BOTH, SOURCE_SERVER):
        send_func = f"send_{snake}"
        header_parts.append(f"bool {send_func}(const {mt.name} &msg);\n")
        source_parts.append(f"bool APIServerConnectionBase::{send_func}(const {mt.name} &msg) {{\n")
        if log:
            source_parts.append("#ifdef HAS_PROTO_MESSAGE_DUMP\n")
            source_parts.append(f' ESP_LOGVV(TAG, "{send_func}: %s", msg.dump().c_str());\n')
            source_parts.append("#endif\n")
        source_parts.append(f" return this->send_message_<{mt.name}>(msg, {message_id});\n")
        source_parts.append("}\n")

    if origin in (SOURCE_BOTH, SOURCE_CLIENT):
        on_func = f"on_{snake}"
        header_parts.append(f"virtual void {on_func}(const {mt.name} &value){{}};\n")
        case_lines = []
        if guard is not None:
            case_lines.append(f"#ifdef {guard}")
        case_lines.append(f"{mt.name} msg;")
        case_lines.append("msg.decode(msg_data, msg_size);")
        if log:
            case_lines.extend(
                [
                    "#ifdef HAS_PROTO_MESSAGE_DUMP",
                    f'ESP_LOGVV(TAG, "{on_func}: %s", msg.dump().c_str());',
                    "#endif",
                ]
            )
        case_lines.append(f"this->{on_func}(msg);")
        if guard is not None:
            case_lines.append("#endif")
        case_lines.append("break;")
        RECEIVE_CASES[message_id] = "\n".join(case_lines)

    if guard is not None:
        header_parts.append("#endif\n")
        source_parts.append("#endif\n")

    return "".join(header_parts), "".join(source_parts)
|
|
|
|
|
|
def _generate_api_pb2(file):
    """Return ``(header, source)`` text for api_pb2.h / api_pb2.cpp.

    The header holds the ``enums`` namespace and one class per message type.
    The source holds the enum-to-string helpers and the message methods.
    """
    content = FILE_HEADER
    content += """\
#pragma once

#include "proto.h"

namespace esphome {
namespace api {

"""

    cpp = FILE_HEADER
    cpp += """\
#include "api_pb2.h"
#include "esphome/core/log.h"

#include <cinttypes>

namespace esphome {
namespace api {

"""

    content += "namespace enums {\n\n"
    for enum in file.enum_type:
        s, c = build_enum_type(enum)
        content += s
        cpp += c
    content += "\n} // namespace enums\n\n"

    for m in file.message_type:
        s, c = build_message_type(m)
        content += s
        cpp += c

    content += """\

} // namespace api
} // namespace esphome
"""
    cpp += """\

} // namespace api
} // namespace esphome
"""
    return content, cpp


def _generate_api_pb2_service(file):
    """Return ``(header, source)`` text for api_pb2_service.h / .cpp.

    Generates two classes:

    * ``APIServerConnectionBase``: send_*/on_* members and the
      ``read_message`` dispatch switch;
    * ``APIServerConnection``: its on_* overrides check connection
      setup/authentication (per method options), call the pure-virtual RPC
      handler and, for non-void methods, send the handler's response.
    """
    hpp = FILE_HEADER
    hpp += """\
#pragma once

#include "api_pb2.h"
#include "esphome/core/defines.h"

namespace esphome {
namespace api {

"""

    cpp = FILE_HEADER
    cpp += """\
#include "api_pb2_service.h"
#include "esphome/core/log.h"

namespace esphome {
namespace api {

static const char *const TAG = "api.service";

"""

    class_name = "APIServerConnectionBase"
    hpp += f"class {class_name} : public ProtoService {{\n"
    hpp += " public:\n"
    for mt in file.message_type:
        obj = build_service_message_type(mt)
        if obj is None:
            continue
        hout, cout = obj
        hpp += indent(hout) + "\n"
        cpp += cout

    hpp += " protected:\n"
    hpp += " bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n"
    out = f"bool {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n"
    out += " switch (msg_type) {\n"
    # RECEIVE_CASES was filled by build_service_message_type() above.
    for i, case in sorted(RECEIVE_CASES.items()):
        c = f"case {i}: {{\n"
        c += indent(case) + "\n"
        c += "}"
        out += indent(c, " ") + "\n"
    out += " default:\n"
    out += " return false;\n"
    out += " }\n"
    out += " return true;\n"
    out += "}\n"
    cpp += out
    hpp += "};\n"

    serv = file.service[0]
    class_name = "APIServerConnection"
    hpp += "\n"
    hpp += f"class {class_name} : public {class_name}Base {{\n"
    hpp += " public:\n"
    hpp_protected = ""
    cpp += "\n"

    for m in serv.method:
        func = m.name
        # Strip the leading character of the qualified type names.
        inp = m.input_type[1:]
        ret = m.output_type[1:]
        is_void = ret == "void"
        snake = camel_to_snake(inp)
        on_func = f"on_{snake}"
        needs_conn = get_opt(m, pb.needs_setup_connection, True)
        needs_auth = get_opt(m, pb.needs_authentication, True)

        ifdef = ifdefs.get(inp, None)

        if ifdef is not None:
            hpp += f"#ifdef {ifdef}\n"
            hpp_protected += f"#ifdef {ifdef}\n"
            cpp += f"#ifdef {ifdef}\n"

        hpp_protected += f" void {on_func}(const {inp} &msg) override;\n"
        hpp += f" virtual {ret} {func}(const {inp} &msg) = 0;\n"
        cpp += f"void {class_name}::{on_func}(const {inp} &msg) {{\n"
        body = ""
        if needs_conn:
            body += "if (!this->is_connection_setup()) {\n"
            body += " this->on_no_setup_connection();\n"
            body += " return;\n"
            body += "}\n"
        if needs_auth:
            body += "if (!this->is_authenticated()) {\n"
            body += " this->on_unauthenticated_access();\n"
            body += " return;\n"
            body += "}\n"

        if is_void:
            body += f"this->{func}(msg);\n"
        else:
            body += f"{ret} ret = this->{func}(msg);\n"
            ret_snake = camel_to_snake(ret)
            body += f"if (!this->send_{ret_snake}(ret)) {{\n"
            body += " this->on_fatal_error();\n"
            body += "}\n"
        cpp += indent(body) + "\n" + "}\n"

        if ifdef is not None:
            hpp += "#endif\n"
            hpp_protected += "#endif\n"
            cpp += "#endif\n"

    hpp += " protected:\n"
    hpp += hpp_protected
    hpp += "};\n"

    hpp += """\

} // namespace api
} // namespace esphome
"""
    cpp += """\

} // namespace api
} // namespace esphome
"""
    return hpp, cpp


def _run_clang_format(paths):
    """Format *paths* in place with clang-format, if it is available.

    Uses the clang-format binary bundled with the ``clang_format`` Python
    package. Does nothing when that package is not installed.
    """
    try:
        import clang_format  # pylint: disable=import-outside-toplevel
    except ImportError:
        return
    clang_format_path = os.path.join(
        os.path.dirname(clang_format.__file__), "data", "bin", "clang-format"
    )
    for path in paths:
        call([clang_format_path, "-i", path])


def main():
    """Compile api.proto and regenerate the four api_pb2* C++ files.

    Returns protoc's exit code when protoc fails, otherwise 0. The value is
    used as the process exit status.
    """
    cwd = Path(__file__).resolve().parent
    root = cwd.parent.parent / "esphome" / "components" / "api"
    prot_file = root / "api.protoc"

    # Compile api.proto into a serialized FileDescriptorSet.
    rc = call(["protoc", "-o", str(prot_file), "-I", str(root), "api.proto"])
    if rc != 0:
        print(f"protoc failed with exit code {rc}", file=sys.stderr)
        return rc
    try:
        proto_content = prot_file.read_bytes()
    finally:
        # Intermediate artifact only: remove it even if reading fails.
        prot_file.unlink()

    # pylint: disable-next=no-member
    d = descriptor.FileDescriptorSet.FromString(proto_content)
    file = d.file[0]

    content, cpp = _generate_api_pb2(file)
    with open(root / "api_pb2.h", "w", encoding="utf-8") as f:
        f.write(content)
    with open(root / "api_pb2.cpp", "w", encoding="utf-8") as f:
        f.write(cpp)

    hpp, cpp = _generate_api_pb2_service(file)
    with open(root / "api_pb2_service.h", "w", encoding="utf-8") as f:
        f.write(hpp)
    with open(root / "api_pb2_service.cpp", "w", encoding="utf-8") as f:
        f.write(cpp)

    _run_clang_format(
        [
            root / "api_pb2_service.h",
            root / "api_pb2_service.cpp",
            root / "api_pb2.h",
            root / "api_pb2.cpp",
        ]
    )
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
|