"""Python 3 script to automatically generate C++ classes for ESPHome's native API.

It's pretty crappy spaghetti code, but it works.
"""

import re
from pathlib import Path
from subprocess import call
from textwrap import dedent
from typing import List, Optional

import google.protobuf.descriptor_pb2 as descriptor

# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import api_options_pb2 as pb

# Directory containing this script; the API component lives two levels up.
cwd = Path(__file__).parent
root = cwd.parent.parent / 'esphome' / 'components' / 'api'
# Temporary file that receives protoc's serialized descriptor output.
prot = cwd / 'api.protoc'
# Paths are passed as str: a list containing path-like objects is only
# accepted by subprocess on Windows from Python 3.8 onwards.
protoc_status = call(['protoc', '-o', str(prot), '-I', str(root), 'api.proto'])
if protoc_status != 0:
    # Without this check a failed run would go on to read a stale or missing file.
    raise SystemExit(f'protoc failed with exit code {protoc_status}')
content = prot.read_bytes()

# Parsed descriptor set describing api.proto.
d = descriptor.FileDescriptorSet.FromString(content)


def indent_list(text, padding=u'  '):
    """Split ``text`` into lines and return them with ``padding`` prepended to each."""
    indented = []
    for raw_line in text.splitlines():
        indented.append(padding + raw_line)
    return indented


def indent(text, padding=u'  '):
    """Return ``text`` with ``padding`` prepended to every line.

    The lines come from ``indent_list`` and are re-joined with a single
    newline character.
    """
    padded_lines = indent_list(text, padding)
    return u'\n'.join(padded_lines)


def camel_to_snake(name):
    """Convert a CamelCase identifier to lower-case snake_case.

    Two-pass regex recipe taken from https://stackoverflow.com/a/1176023
    """
    # Pass 1: put an underscore in front of every capitalised word that
    # follows another character ("BinarySensor" -> "Binary_Sensor").
    capitalised_word = re.compile('(.)([A-Z][a-z]+)')
    partially_split = capitalised_word.sub(r'\1_\2', name)
    # Pass 2: separate a lower-case letter or digit from a directly
    # following capital, then lower-case the whole result.
    lower_then_upper = re.compile('([a-z0-9])([A-Z])')
    fully_split = lower_then_upper.sub(r'\1_\2', partially_split)
    return fully_split.lower()


class TypeInfo:
    """Describe how a single protobuf field is rendered in the generated C++.

    Wraps a field descriptor and exposes the C++ snippets needed for that
    field: the member declaration, one ``case`` per decode method, the
    encode call and the dump statements.  Subclasses supply ``cpp_type``,
    ``default_value``, ``encode_func``, ``dump`` and whichever ``decode_*``
    expression applies to them.
    """

    # C++ expressions converting the decoded `value` to this type.  None
    # means the type produces no case for that decode method.
    decode_varint = None
    decode_length = None
    decode_32bit = None
    decode_64bit = None

    # Name of the C++ `buffer` method used by encode_content.
    encode_func = None

    # Callable ``dump(name)`` returning C++ statements that append the
    # expression ``name`` to `out`; concrete types must override it.
    dump = None

    def __init__(self, field):
        """Store the field descriptor this object describes."""
        self._field = field

    @property
    def default_value(self):
        """C++ initializer placed between the braces of the member declaration."""
        return ''

    @property
    def name(self):
        """Field name as declared in the descriptor."""
        return self._field.name

    @property
    def arg_name(self):
        """Name used for the field as an argument; same as ``name`` here."""
        return self.name

    @property
    def field_name(self):
        """Name of the generated C++ member; same as ``name`` here."""
        return self.name

    @property
    def number(self):
        """Protobuf field number, used as the ``case`` label and encode id."""
        return self._field.number

    @property
    def repeated(self):
        """True when the field's label is 3 (LABEL_REPEATED in descriptor.proto)."""
        return self._field.label == 3

    @property
    def cpp_type(self):
        """C++ type of the member; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def reference_type(self):
        """Type used to pass the value mutably (by value unless overridden)."""
        return f'{self.cpp_type} '

    @property
    def const_reference_type(self):
        """Type used to pass the value read-only (by value unless overridden)."""
        return f'{self.cpp_type} '

    @property
    def public_content(self) -> List[str]:
        """Lines contributed to the ``public:`` section of the generated class."""
        return [self.class_member]

    @property
    def protected_content(self) -> List[str]:
        """Lines contributed to the ``protected:`` section of the generated class."""
        return []

    @property
    def class_member(self) -> str:
        """C++ member declaration with a brace initializer."""
        return f'{self.cpp_type} {self.field_name}{{{self.default_value}}};  // NOLINT'

    def _decode_case(self, content: Optional[str]) -> Optional[str]:
        """Return the ``case`` block assigning ``content`` to the member, or None."""
        if content is None:
            return None
        return dedent(f'''\
        case {self.number}: {{
          this->{self.field_name} = {content};
          return true;
        }}''')

    @property
    def decode_varint_content(self) -> Optional[str]:
        """``case`` block for the varint decoder, or None if not applicable."""
        return self._decode_case(self.decode_varint)

    @property
    def decode_length_content(self) -> Optional[str]:
        """``case`` block for the length-delimited decoder, or None if not applicable."""
        return self._decode_case(self.decode_length)

    @property
    def decode_32bit_content(self) -> Optional[str]:
        """``case`` block for the 32-bit decoder, or None if not applicable."""
        return self._decode_case(self.decode_32bit)

    @property
    def decode_64bit_content(self) -> Optional[str]:
        """``case`` block for the 64-bit decoder, or None if not applicable."""
        return self._decode_case(self.decode_64bit)

    @property
    def encode_content(self):
        """C++ statement that writes the member through ``encode_func``."""
        return f'buffer.{self.encode_func}({self.number}, this->{self.field_name});'

    @property
    def dump_content(self) -> str:
        """C++ statements that append ``  <name>: <value>\\n`` to `out`.

        Raises:
            NotImplementedError: if the concrete type did not define ``dump``.
        """
        if self.dump is None:
            # Fail with a message naming the type instead of the opaque
            # "'NoneType' object is not callable".
            raise NotImplementedError(f'{type(self).__name__} does not implement dump()')
        member = f'this->{self.field_name}'
        return (
            f'out.append("  {self.name}: ");\n'
            f'{self.dump(member)}\n'
            'out.append("\\n");\n'
        )


# Registry filled by the register_type decorator: key -> registered object.
TYPE_INFO = {}


def register_type(name):
    """Return a decorator that stores the decorated object in TYPE_INFO[name].

    The decorated object is returned unchanged, so this can be stacked on a
    class definition without altering it.
    """
    def decorator(registered):
        TYPE_INFO.update({name: registered})
        return registered

    return decorator


@register_type(1)
class DoubleType(TypeInfo):
    """Type info registered under id 1: a C++ ``double`` decoded from 64-bit values."""
    encode_func = 'encode_double'
    decode_64bit = 'value.as_double()'
    default_value = '0.0'
    cpp_type = 'double'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` with ``%g`` into `buffer` and appends it to `out`."""
        statements = (
            'sprintf(buffer, "%g", {});'.format(name),
            'out.append(buffer);',
        )
        return '\n'.join(statements)


@register_type(2)
class FloatType(TypeInfo):
    """Type info registered under id 2: a C++ ``float`` decoded from 32-bit values."""
    encode_func = 'encode_float'
    decode_32bit = 'value.as_float()'
    default_value = '0.0f'
    cpp_type = 'float'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` with ``%g`` into `buffer` and appends it to `out`."""
        statements = (
            'sprintf(buffer, "%g", {});'.format(name),
            'out.append(buffer);',
        )
        return '\n'.join(statements)


@register_type(3)
class Int64Type(TypeInfo):
    """Type info registered under id 3: a C++ ``int64_t`` decoded from varints."""
    cpp_type = 'int64_t'
    default_value = '0'
    decode_varint = 'value.as_int64()'
    encode_func = 'encode_int64'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` into `buffer` and appends it to `out`."""
        # "%lld": `ll` alone is only a length modifier, so the previous "%ll"
        # was not a valid conversion specification.
        o = f'sprintf(buffer, "%lld", {name});\n'
        o += 'out.append(buffer);'
        return o


@register_type(4)
class UInt64Type(TypeInfo):
    """Type info registered under id 4: a C++ ``uint64_t`` decoded from varints."""
    cpp_type = 'uint64_t'
    default_value = '0'
    decode_varint = 'value.as_uint64()'
    encode_func = 'encode_uint64'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` into `buffer` and appends it to `out`."""
        # "%llu": the length modifier precedes the conversion specifier.  The
        # previous "%ull" formatted an unsigned int followed by a literal "ll".
        o = f'sprintf(buffer, "%llu", {name});\n'
        o += 'out.append(buffer);'
        return o


@register_type(5)
class Int32Type(TypeInfo):
    """Type info registered under id 5: a C++ ``int32_t`` decoded from varints."""
    encode_func = 'encode_int32'
    decode_varint = 'value.as_int32()'
    default_value = '0'
    cpp_type = 'int32_t'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` with ``%d`` into `buffer` and appends it to `out`."""
        statements = (
            'sprintf(buffer, "%d", {});'.format(name),
            'out.append(buffer);',
        )
        return '\n'.join(statements)


@register_type(6)
class Fixed64Type(TypeInfo):
    """Type info registered under id 6: a C++ ``uint64_t`` decoded from 64-bit values."""
    cpp_type = 'uint64_t'
    default_value = '0'
    decode_64bit = 'value.as_fixed64()'
    encode_func = 'encode_fixed64'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` into `buffer` and appends it to `out`."""
        # "%llu": the length modifier precedes the conversion specifier.  The
        # previous "%ull" formatted an unsigned int followed by a literal "ll".
        o = f'sprintf(buffer, "%llu", {name});\n'
        o += 'out.append(buffer);'
        return o


@register_type(7)
class Fixed32Type(TypeInfo):
    """Type info registered under id 7: a C++ ``uint32_t`` decoded from 32-bit values."""
    encode_func = 'encode_fixed32'
    decode_32bit = 'value.as_fixed32()'
    default_value = '0'
    cpp_type = 'uint32_t'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` with ``%u`` into `buffer` and appends it to `out`."""
        statements = (
            'sprintf(buffer, "%u", {});'.format(name),
            'out.append(buffer);',
        )
        return '\n'.join(statements)


@register_type(8)
class BoolType(TypeInfo):
    """Type info registered under id 8: a C++ ``bool`` decoded from varints."""
    encode_func = 'encode_bool'
    decode_varint = 'value.as_bool()'
    default_value = 'false'
    cpp_type = 'bool'

    def dump(self, name):
        """Return C++ that appends ``YESNO(name)`` to `out`."""
        return 'out.append(YESNO({}));'.format(name)


@register_type(9)
class StringType(TypeInfo):
    """Type info registered under id 9: a ``std::string`` decoded from length-delimited values."""
    encode_func = 'encode_string'
    decode_length = 'value.as_string()'
    const_reference_type = 'const std::string &'
    reference_type = 'std::string &'
    default_value = ''
    cpp_type = 'std::string'

    def dump(self, name):
        """Return C++ that appends ``name`` to `out`, wrapped in single quotes."""
        # C++ string literal consisting of one single-quote character.
        quote_literal = '"\'"'
        return 'out.append({0}).append({1}).append({0});'.format(quote_literal, name)


@register_type(11)
class MessageType(TypeInfo):
    """Type info registered under id 11: a field holding another generated message."""

    default_value = ''

    @property
    def cpp_type(self):
        """C++ class name: the descriptor's ``type_name`` without its first character.

        Presumably that character is the leading '.' of a qualified name.
        """
        qualified = self._field.type_name
        return qualified[1:]

    @property
    def reference_type(self):
        """Mutable reference to the message type."""
        return '{} &'.format(self.cpp_type)

    @property
    def const_reference_type(self):
        """Const reference to the message type."""
        return 'const {} &'.format(self.cpp_type)

    @property
    def encode_func(self):
        """Templated encode method instantiated for this message type."""
        return 'encode_message<{}>'.format(self.cpp_type)

    @property
    def decode_length(self):
        """Expression converting a length-delimited value into this message type."""
        return 'value.as_message<{}>()'.format(self.cpp_type)

    def dump(self, name):
        """Return C++ that lets the nested message append itself to `out`."""
        return '{}.dump_to(out);'.format(name)


@register_type(12)
class BytesType(TypeInfo):
    """Type info registered under id 12: handled exactly like a string (``std::string``)."""
    encode_func = 'encode_string'
    decode_length = 'value.as_string()'
    const_reference_type = 'const std::string &'
    reference_type = 'std::string &'
    default_value = ''
    cpp_type = 'std::string'

    def dump(self, name):
        """Return C++ that appends ``name`` to `out`, wrapped in single quotes."""
        # C++ string literal consisting of one single-quote character.
        quote_literal = '"\'"'
        return 'out.append({0}).append({1}).append({0});'.format(quote_literal, name)


@register_type(13)
class UInt32Type(TypeInfo):
    """Type info registered under id 13: a C++ ``uint32_t`` decoded from varints."""
    encode_func = 'encode_uint32'
    decode_varint = 'value.as_uint32()'
    default_value = '0'
    cpp_type = 'uint32_t'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` with ``%u`` into `buffer` and appends it to `out`."""
        statements = (
            'sprintf(buffer, "%u", {});'.format(name),
            'out.append(buffer);',
        )
        return '\n'.join(statements)


@register_type(14)
class EnumType(TypeInfo):
    """Type info registered under id 14: a field whose type is a generated enum."""

    default_value = ''

    @property
    def cpp_type(self):
        """C++ enum name: the descriptor's ``type_name`` without its first character.

        Presumably that character is the leading '.' of a qualified name.
        """
        qualified = self._field.type_name
        return qualified[1:]

    @property
    def decode_varint(self):
        """Expression converting a varint value into this enum type."""
        return 'value.as_enum<{}>()'.format(self.cpp_type)

    @property
    def encode_func(self):
        """Templated encode method instantiated for this enum type."""
        return 'encode_enum<{}>'.format(self.cpp_type)

    def dump(self, name):
        """Return C++ that appends the enum's string form to `out`."""
        return 'out.append(proto_enum_to_string<{}>({}));'.format(self.cpp_type, name)


@register_type(15)
class SFixed32Type(TypeInfo):
    """Type info registered under id 15: a C++ ``int32_t`` decoded from 32-bit values."""
    encode_func = 'encode_sfixed32'
    decode_32bit = 'value.as_sfixed32()'
    default_value = '0'
    cpp_type = 'int32_t'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` with ``%d`` into `buffer` and appends it to `out`."""
        statements = (
            'sprintf(buffer, "%d", {});'.format(name),
            'out.append(buffer);',
        )
        return '\n'.join(statements)


@register_type(16)
class SFixed64Type(TypeInfo):
    """Type info registered under id 16: a C++ ``int64_t`` decoded from 64-bit values."""
    cpp_type = 'int64_t'
    default_value = '0'
    decode_64bit = 'value.as_sfixed64()'
    encode_func = 'encode_sfixed64'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` into `buffer` and appends it to `out`."""
        # "%lld": `ll` alone is only a length modifier, so the previous "%ll"
        # was not a valid conversion specification.
        o = f'sprintf(buffer, "%lld", {name});\n'
        o += 'out.append(buffer);'
        return o


@register_type(17)
class SInt32Type(TypeInfo):
    """Type info registered under id 17: a C++ ``int32_t`` decoded via ``as_sint32``."""
    encode_func = 'encode_sint32'
    decode_varint = 'value.as_sint32()'
    default_value = '0'
    cpp_type = 'int32_t'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` with ``%d`` into `buffer` and appends it to `out`."""
        statements = (
            'sprintf(buffer, "%d", {});'.format(name),
            'out.append(buffer);',
        )
        return '\n'.join(statements)


@register_type(18)
class SInt64Type(TypeInfo):
    """Type info registered under id 18: a C++ ``int64_t`` decoded via ``as_sint64``."""
    cpp_type = 'int64_t'
    default_value = '0'
    decode_varint = 'value.as_sint64()'
    # Was misspelled 'encode_sin64'; named like 'encode_sint32' / 'as_sint64'.
    encode_func = 'encode_sint64'

    def dump(self, name):
        """Return C++ that sprintf's ``name`` into `buffer` and appends it to `out`.

        ``name`` is the C++ expression to print.  The parameter was missing
        before, so every call with an argument raised a TypeError.
        """
        # "%lld": `ll` alone is only a length modifier, so the previous "%ll"
        # was not a valid conversion specification.
        o = f'sprintf(buffer, "%lld", {name});\n'
        o += 'out.append(buffer);'
        return o


class RepeatedTypeInfo(TypeInfo):
    """Type info for a repeated field, stored as a ``std::vector`` of the element type.

    Element-level details (C++ type, decode expressions, encode method and
    dump code) are delegated to the TypeInfo registered for the field's type.
    """

    def __init__(self, field):
        """Wrap ``field`` and build the TypeInfo describing a single element."""
        super().__init__(field)
        self._ti = TYPE_INFO[field.type](field)

    @property
    def cpp_type(self):
        """Vector of the element's C++ type."""
        return f'std::vector<{self._ti.cpp_type}>'

    @property
    def reference_type(self):
        """Mutable reference to the vector."""
        return f'{self.cpp_type} &'

    @property
    def const_reference_type(self):
        """Const reference to the vector."""
        return f'const {self.cpp_type} &'

    def _push_back_case(self, content: Optional[str]) -> Optional[str]:
        """Return the ``case`` block appending ``content`` to the vector, or None."""
        if content is None:
            return None
        return dedent(f'''\
        case {self.number}: {{
          this->{self.field_name}.push_back({content});
          return true;
        }}''')

    @property
    def decode_varint_content(self) -> Optional[str]:
        """``case`` block for the varint decoder, or None if not applicable."""
        return self._push_back_case(self._ti.decode_varint)

    @property
    def decode_length_content(self) -> Optional[str]:
        """``case`` block for the length-delimited decoder, or None if not applicable."""
        return self._push_back_case(self._ti.decode_length)

    @property
    def decode_32bit_content(self) -> Optional[str]:
        """``case`` block for the 32-bit decoder, or None if not applicable."""
        return self._push_back_case(self._ti.decode_32bit)

    @property
    def decode_64bit_content(self) -> Optional[str]:
        """``case`` block for the 64-bit decoder, or None if not applicable."""
        return self._push_back_case(self._ti.decode_64bit)

    @property
    def _ti_is_bool(self):
        """True when elements are bools, which must be iterated by value."""
        # std::vector is specialized for bool, reference does not work
        return isinstance(self._ti, BoolType)

    @property
    def encode_content(self):
        """C++ loop that encodes every element of the vector."""
        ref = '' if self._ti_is_bool else '&'
        # dedent() strips the template's source indentation, which previously
        # leaked eight extra spaces into every line of the generated loop.
        return dedent(f"""\
        for (auto {ref}it : this->{self.field_name}) {{
          buffer.{self._ti.encode_func}({self.number}, it, true);
        }}""")

    @property
    def dump_content(self):
        """C++ loop that appends one ``  <name>: <value>\\n`` entry per element to `out`."""
        ref = '' if self._ti_is_bool else '&'
        o = f'for (const auto {ref}it : this->{self.field_name}) {{\n'
        o += f'  out.append("  {self.name}: ");\n'
        o += indent(self._ti.dump('it')) + '\n'
        o += '  out.append("\\n");\n'
        o += '}\n'
        return o


def build_enum_type(desc):
    """Generate the C++ code for one enum descriptor.

    Returns a ``(header, source)`` tuple: the ``enum`` declaration, and the
    ``proto_enum_to_string`` specialization that maps each value to its name
    (falling back to "UNKNOWN").
    """
    enumerators = ''.join(f'  {v.name} = {v.number},\n' for v in desc.value)
    header = f'enum {desc.name} : uint32_t {{\n{enumerators}}};\n'

    name_cases = ''.join(f'    case {v.name}: return "{v.name}";\n' for v in desc.value)
    source_parts = [
        'template<>\n',
        f'const char *proto_enum_to_string<{desc.name}>({desc.name} value) {{\n',
        '  switch (value) {\n',
        name_cases,
        '    default: return "UNKNOWN";\n',
        '  }\n',
        '}\n',
    ]
    return header, ''.join(source_parts)


def build_message_type(desc):
    """Generate the C++ class for one message descriptor.

    Every field contributes its member declaration, an encode statement, a
    dump snippet and a ``case`` for whichever decode methods apply to it.

    Returns a ``(header, source)`` tuple: the class declaration deriving
    from ``ProtoMessage``, and the definitions of ``encode``, ``dump_to``
    and each ``decode_*`` method that has at least one case.
    """
    public_content = []
    protected_content = []
    encode = []
    dump = []
    # (C++ method name, C++ type of its `value` argument, collected cases).
    # The order here is the order of the definitions in the source output.
    # NOTE(review): 'Proto64bit' differs in casing from 'Proto32Bit'; kept
    # as-is here - confirm against the class name declared in proto.h.
    decoders = [
        ('decode_varint', 'ProtoVarInt', []),
        ('decode_length', 'ProtoLengthDelimited', []),
        ('decode_32bit', 'Proto32Bit', []),
        ('decode_64bit', 'Proto64bit', []),
    ]

    for field in desc.field:
        # Label 3 marks a repeated field.
        ti = RepeatedTypeInfo(field) if field.label == 3 else TYPE_INFO[field.type](field)
        protected_content.extend(ti.protected_content)
        public_content.extend(ti.public_content)
        encode.append(ti.encode_content)

        for method, _, cases in decoders:
            case_snippet = getattr(ti, f'{method}_content')
            if case_snippet:
                cases.append(case_snippet)
        dump_snippet = ti.dump_content
        if dump_snippet:
            dump.append(dump_snippet)

    source_parts = []
    for method, value_type, cases in decoders:
        if not cases:
            continue
        cases.append('default:\n  return false;')
        params = f'uint32_t field_id, {value_type} value'
        source_parts += [
            f'bool {desc.name}::{method}({params}) {{\n',
            '  switch (field_id) {\n',
            indent('\n'.join(cases), '    ') + '\n',
            '  }\n',
            '}\n',
        ]
        # Each declaration goes to the front, so later decoders end up first.
        protected_content.insert(0, f'bool {method}({params}) override;')

    source_parts += [
        f'void {desc.name}::encode(ProtoWriteBuffer buffer) const {{\n',
        indent('\n'.join(encode)) + '\n',
        '}\n',
    ]
    public_content.append('void encode(ProtoWriteBuffer buffer) const override;')

    source_parts.append(f'void {desc.name}::dump_to(std::string &out) const {{\n')
    if dump:
        source_parts += [
            '  char buffer[64];\n',
            f'  out.append("{desc.name} {{\\n");\n',
            indent('\n'.join(dump)) + '\n',
            '  out.append("}");\n',
        ]
    else:
        source_parts.append(f'  out.append("{desc.name} {{}}");\n')
    source_parts.append('}\n')
    public_content.append('void dump_to(std::string &out) const override;')

    header_parts = [
        f'class {desc.name} : public ProtoMessage {{\n',
        ' public:\n',
        indent('\n'.join(public_content)) + '\n',
        ' protected:\n',
        indent('\n'.join(protected_content)) + '\n',
        '};\n',
    ]
    return ''.join(header_parts), ''.join(source_parts)


# First file in the descriptor set (presumably api.proto, the only input).
file = d.file[0]

header_parts = ['''\
#pragma once

#include "proto.h"

namespace esphome {
namespace api {

''']

source_parts = ['''\
#include "api_pb2.h"
#include "esphome/core/log.h"

namespace esphome {
namespace api {

''']

# Enums are emitted before messages, each in declaration order.
for enum_desc in file.enum_type:
    declaration, definition = build_enum_type(enum_desc)
    header_parts.append(declaration)
    source_parts.append(definition)

for message_desc in file.message_type:
    declaration, definition = build_message_type(message_desc)
    header_parts.append(declaration)
    source_parts.append(definition)

namespace_close = '''\

}  // namespace api
}  // namespace esphome
'''
header_parts.append(namespace_close)
source_parts.append(namespace_close)

content = ''.join(header_parts)
cpp = ''.join(source_parts)

# Write the generated message classes next to api.proto.
(root / 'api_pb2.h').write_text(content)
(root / 'api_pb2.cpp').write_text(cpp)

# Values of a message's `source` option as tested in
# build_service_message_type: BOTH and SERVER generate a send function,
# BOTH and CLIENT generate a receive handler.
SOURCE_BOTH = 0
SOURCE_SERVER = 1
SOURCE_CLIENT = 2

# Message id -> C++ text of the `case` body that handles that message.
RECEIVE_CASES = {}

# Name of the C++ class currently being generated; reassigned further down.
class_name = 'APIServerConnectionBase'

# Message name -> preprocessor symbol taken from the message's `ifdef` option.
ifdefs = {}


def get_opt(desc, opt, default=None):
    """Return extension option ``opt`` of ``desc``, or ``default`` when it is not set."""
    if desc.options.HasExtension(opt):
        return desc.options.Extensions[opt]
    return default


def build_service_message_type(mt):
    """Generate the service-class code for one message type.

    Messages without an ``id`` option are skipped (returns None).  Otherwise
    returns ``(hout, cout)``: declarations for the class header and
    definitions for the source file.  Depending on the ``source`` option
    this covers a ``send_<name>`` function, an ``on_<name>`` handler, or
    both.  As side effects, the receive ``case`` body is stored in
    RECEIVE_CASES under the message id and any ``ifdef`` option is recorded
    in ``ifdefs``.
    """
    id_ = get_opt(mt, pb.id)
    if id_ is None:
        return None

    snake = camel_to_snake(mt.name)
    source = get_opt(mt, pb.source, 0)
    ifdef = get_opt(mt, pb.ifdef)
    log = get_opt(mt, pb.log, True)
    nodelay = get_opt(mt, pb.no_delay, False)
    guarded = ifdef is not None

    header_lines = []
    source_lines = []

    if guarded:
        ifdefs[str(mt.name)] = ifdef
        header_lines.append(f'#ifdef {ifdef}')
        source_lines.append(f'#ifdef {ifdef}')

    if source in (SOURCE_BOTH, SOURCE_SERVER):
        # Generate send
        send_func = f'send_{snake}'
        header_lines.append(f'bool {send_func}(const {mt.name} &msg);')
        source_lines.append(f'bool {class_name}::{send_func}(const {mt.name} &msg) {{')
        if log:
            source_lines.append(f'  ESP_LOGVV(TAG, "{send_func}: %s", msg.dump().c_str());')
        source_lines.append(f'  this->set_nodelay({str(nodelay).lower()});')
        source_lines.append(f'  return this->send_message_<{mt.name}>(msg, {id_});')
        source_lines.append('}')

    if source in (SOURCE_BOTH, SOURCE_CLIENT):
        # Generate receive
        receive_func = f'on_{snake}'
        header_lines.append(f'virtual void {receive_func}(const {mt.name} &value){{}};')
        case_lines = []
        if guarded:
            case_lines.append(f'#ifdef {ifdef}')
        case_lines.append(f'{mt.name} msg;')
        case_lines.append('msg.decode(msg_data, msg_size);')
        if log:
            case_lines.append(f'ESP_LOGVV(TAG, "{receive_func}: %s", msg.dump().c_str());')
        case_lines.append(f'this->{receive_func}(msg);')
        if guarded:
            case_lines.append('#endif')
        # `break;` stays outside the #ifdef and carries no trailing newline.
        case_lines.append('break;')
        RECEIVE_CASES[id_] = '\n'.join(case_lines)

    if guarded:
        header_lines.append('#endif')
        source_lines.append('#endif')

    hout = ''.join(line + '\n' for line in header_lines)
    cout = ''.join(line + '\n' for line in source_lines)
    return hout, cout


hpp = '''\
#pragma once

#include "api_pb2.h"
#include "esphome/core/defines.h"

namespace esphome {
namespace api {

'''

cpp = '''\
#include "api_pb2_service.h"
#include "esphome/core/log.h"

namespace esphome {
namespace api {

static const char *TAG = "api.service";

'''

hpp += f'class {class_name} : public ProtoService {{\n'
hpp += ' public:\n'

# Per-message send functions and receive handlers; messages for which the
# builder returns None contribute nothing.
for message_desc in file.message_type:
    generated = build_service_message_type(message_desc)
    if generated is not None:
        declarations, definitions = generated
        hpp += indent(declarations) + '\n'
        cpp += definitions

# read_message(): one `case` per registered message id, in ascending order.
read_message_sig = 'read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data)'
hpp += ' protected:\n'
hpp += f'  bool {read_message_sig} override;\n'

dispatch_parts = [
    f'bool {class_name}::{read_message_sig} {{\n',
    '  switch(msg_type) {\n',
]
for msg_id, case_body in sorted(RECEIVE_CASES.items()):
    case_block = f'case {msg_id}: {{\n{indent(case_body)}\n}}'
    dispatch_parts.append(indent(case_block, '    ') + '\n')
dispatch_parts += [
    '    default: \n',
    '      return false;\n',
    '  }\n',
    '  return true;\n',
    '}\n',
]
cpp += ''.join(dispatch_parts)
hpp += '};\n'

# Second generated class: implements each on_<request> handler of the base
# class by checking connection state and delegating to a pure virtual method
# named after the service method.
serv = file.service[0]
class_name = 'APIServerConnection'
hpp += '\n'
hpp += f'class {class_name} : public {class_name}Base {{\n'
hpp += ' public:\n'
hpp_protected = ''
cpp += '\n'

# The loop variable is bound by the loop itself; the former
# `m = serv.method[0]` pre-assignment was dead code and raised IndexError
# for a service without methods.
for m in serv.method:
    func = m.name
    # Type names are used without their first character.
    inp = m.input_type[1:]
    ret = m.output_type[1:]
    # An output type named `void` means nothing is sent back.
    is_void = ret == 'void'
    snake = camel_to_snake(inp)
    on_func = f'on_{snake}'
    needs_conn = get_opt(m, pb.needs_setup_connection, True)
    needs_auth = get_opt(m, pb.needs_authentication, True)

    # Reuse the guard recorded for the request message, if it had one.
    ifdef = ifdefs.get(inp, None)

    if ifdef is not None:
        hpp += f'#ifdef {ifdef}\n'
        hpp_protected += f'#ifdef {ifdef}\n'
        cpp += f'#ifdef {ifdef}\n'

    hpp_protected += f'  void {on_func}(const {inp} &msg) override;\n'
    hpp += f'  virtual {ret} {func}(const {inp} &msg) = 0;\n'
    cpp += f'void {class_name}::{on_func}(const {inp} &msg) {{\n'
    body = ''
    if needs_conn:
        body += 'if (!this->is_connection_setup()) {\n'
        body += '  this->on_no_setup_connection();\n'
        body += '  return;\n'
        body += '}\n'
    if needs_auth:
        body += 'if (!this->is_authenticated()) {\n'
        body += '  this->on_unauthenticated_access();\n'
        body += '  return;\n'
        body += '}\n'

    if is_void:
        body += f'this->{func}(msg);\n'
    else:
        # Send the handler's return value back; a failed send is fatal.
        body += f'{ret} ret = this->{func}(msg);\n'
        ret_snake = camel_to_snake(ret)
        body += f'if (!this->send_{ret_snake}(ret)) {{\n'
        body += '  this->on_fatal_error();\n'
        body += '}\n'
    cpp += indent(body) + '\n' + '}\n'

    if ifdef is not None:
        hpp += '#endif\n'
        hpp_protected += '#endif\n'
        cpp += '#endif\n'

hpp += ' protected:\n'
hpp += hpp_protected
hpp += '};\n'

# Close both namespaces in the service header and source.
namespace_close = '''\

}  // namespace api
}  // namespace esphome
'''
hpp += namespace_close
cpp += namespace_close

# Write the generated service classes into the API component directory.
(root / 'api_pb2_service.h').write_text(hpp)
(root / 'api_pb2_service.cpp').write_text(cpp)

# Delete the descriptor file that protoc wrote at the start of the script.
prot.unlink()