mirror of
https://github.com/esphome/esphome.git
synced 2025-09-01 10:52:19 +01:00
Create Protobuf Plugin for automatically generating native API stubs (#633)
* Create Protobuf Plugin for automatically generating native API stubs * Format * Delete api.proto * Cleanup, use no_delay conditionally * Updates * Update * Lint * Lint * Fixes * Camera * CustomAPIDevice * Fix negative VarInt, Add User-defined services arrays * Home Assistant Event * Fixes * Update custom_api_device.h
This commit is contained in:
168
script/api_protobuf/api_options_pb2.py
Normal file
168
script/api_protobuf/api_options_pb2.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: api_options.proto
|
||||
|
||||
import sys
|
||||
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
|
||||
from google.protobuf.internal import enum_type_wrapper
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from google.protobuf import reflection as _reflection
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor.FileDescriptor(
|
||||
name='api_options.proto',
|
||||
package='',
|
||||
syntax='proto2',
|
||||
serialized_options=None,
|
||||
serialized_pb=_b('\n\x11\x61pi_options.proto\x1a google/protobuf/descriptor.proto\"\x06\n\x04void*F\n\rAPISourceType\x12\x0f\n\x0bSOURCE_BOTH\x10\x00\x12\x11\n\rSOURCE_SERVER\x10\x01\x12\x11\n\rSOURCE_CLIENT\x10\x02:E\n\x16needs_setup_connection\x12\x1e.google.protobuf.MethodOptions\x18\x8e\x08 \x01(\x08:\x04true:C\n\x14needs_authentication\x12\x1e.google.protobuf.MethodOptions\x18\x8f\x08 \x01(\x08:\x04true:/\n\x02id\x12\x1f.google.protobuf.MessageOptions\x18\x8c\x08 \x01(\r:\x01\x30:M\n\x06source\x12\x1f.google.protobuf.MessageOptions\x18\x8d\x08 \x01(\x0e\x32\x0e.APISourceType:\x0bSOURCE_BOTH:/\n\x05ifdef\x12\x1f.google.protobuf.MessageOptions\x18\x8e\x08 \x01(\t:3\n\x03log\x12\x1f.google.protobuf.MessageOptions\x18\x8f\x08 \x01(\x08:\x04true:9\n\x08no_delay\x12\x1f.google.protobuf.MessageOptions\x18\x90\x08 \x01(\x08:\x05\x66\x61lse')
|
||||
,
|
||||
dependencies=[google_dot_protobuf_dot_descriptor__pb2.DESCRIPTOR,])
|
||||
|
||||
_APISOURCETYPE = _descriptor.EnumDescriptor(
|
||||
name='APISourceType',
|
||||
full_name='APISourceType',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
values=[
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='SOURCE_BOTH', index=0, number=0,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='SOURCE_SERVER', index=1, number=1,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='SOURCE_CLIENT', index=2, number=2,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
],
|
||||
containing_type=None,
|
||||
serialized_options=None,
|
||||
serialized_start=63,
|
||||
serialized_end=133,
|
||||
)
|
||||
_sym_db.RegisterEnumDescriptor(_APISOURCETYPE)
|
||||
|
||||
APISourceType = enum_type_wrapper.EnumTypeWrapper(_APISOURCETYPE)
|
||||
SOURCE_BOTH = 0
|
||||
SOURCE_SERVER = 1
|
||||
SOURCE_CLIENT = 2
|
||||
|
||||
NEEDS_SETUP_CONNECTION_FIELD_NUMBER = 1038
|
||||
needs_setup_connection = _descriptor.FieldDescriptor(
|
||||
name='needs_setup_connection', full_name='needs_setup_connection', index=0,
|
||||
number=1038, type=8, cpp_type=7, label=1,
|
||||
has_default_value=True, default_value=True,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=True, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR)
|
||||
NEEDS_AUTHENTICATION_FIELD_NUMBER = 1039
|
||||
needs_authentication = _descriptor.FieldDescriptor(
|
||||
name='needs_authentication', full_name='needs_authentication', index=1,
|
||||
number=1039, type=8, cpp_type=7, label=1,
|
||||
has_default_value=True, default_value=True,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=True, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR)
|
||||
ID_FIELD_NUMBER = 1036
|
||||
id = _descriptor.FieldDescriptor(
|
||||
name='id', full_name='id', index=2,
|
||||
number=1036, type=13, cpp_type=3, label=1,
|
||||
has_default_value=True, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=True, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR)
|
||||
SOURCE_FIELD_NUMBER = 1037
|
||||
source = _descriptor.FieldDescriptor(
|
||||
name='source', full_name='source', index=3,
|
||||
number=1037, type=14, cpp_type=8, label=1,
|
||||
has_default_value=True, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=True, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR)
|
||||
IFDEF_FIELD_NUMBER = 1038
|
||||
ifdef = _descriptor.FieldDescriptor(
|
||||
name='ifdef', full_name='ifdef', index=4,
|
||||
number=1038, type=9, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=_b("").decode('utf-8'),
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=True, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR)
|
||||
LOG_FIELD_NUMBER = 1039
|
||||
log = _descriptor.FieldDescriptor(
|
||||
name='log', full_name='log', index=5,
|
||||
number=1039, type=8, cpp_type=7, label=1,
|
||||
has_default_value=True, default_value=True,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=True, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR)
|
||||
NO_DELAY_FIELD_NUMBER = 1040
|
||||
no_delay = _descriptor.FieldDescriptor(
|
||||
name='no_delay', full_name='no_delay', index=6,
|
||||
number=1040, type=8, cpp_type=7, label=1,
|
||||
has_default_value=True, default_value=False,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=True, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR)
|
||||
|
||||
|
||||
_VOID = _descriptor.Descriptor(
|
||||
name='void',
|
||||
full_name='void',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto2',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=55,
|
||||
serialized_end=61,
|
||||
)
|
||||
|
||||
DESCRIPTOR.message_types_by_name['void'] = _VOID
|
||||
DESCRIPTOR.enum_types_by_name['APISourceType'] = _APISOURCETYPE
|
||||
DESCRIPTOR.extensions_by_name['needs_setup_connection'] = needs_setup_connection
|
||||
DESCRIPTOR.extensions_by_name['needs_authentication'] = needs_authentication
|
||||
DESCRIPTOR.extensions_by_name['id'] = id
|
||||
DESCRIPTOR.extensions_by_name['source'] = source
|
||||
DESCRIPTOR.extensions_by_name['ifdef'] = ifdef
|
||||
DESCRIPTOR.extensions_by_name['log'] = log
|
||||
DESCRIPTOR.extensions_by_name['no_delay'] = no_delay
|
||||
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
||||
|
||||
void = _reflection.GeneratedProtocolMessageType('void', (_message.Message,), dict(
|
||||
DESCRIPTOR = _VOID,
|
||||
__module__ = 'api_options_pb2'
|
||||
# @@protoc_insertion_point(class_scope:void)
|
||||
))
|
||||
_sym_db.RegisterMessage(void)
|
||||
|
||||
google_dot_protobuf_dot_descriptor__pb2.MethodOptions.RegisterExtension(needs_setup_connection)
|
||||
google_dot_protobuf_dot_descriptor__pb2.MethodOptions.RegisterExtension(needs_authentication)
|
||||
google_dot_protobuf_dot_descriptor__pb2.MessageOptions.RegisterExtension(id)
|
||||
source.enum_type = _APISOURCETYPE
|
||||
google_dot_protobuf_dot_descriptor__pb2.MessageOptions.RegisterExtension(source)
|
||||
google_dot_protobuf_dot_descriptor__pb2.MessageOptions.RegisterExtension(ifdef)
|
||||
google_dot_protobuf_dot_descriptor__pb2.MessageOptions.RegisterExtension(log)
|
||||
google_dot_protobuf_dot_descriptor__pb2.MessageOptions.RegisterExtension(no_delay)
|
||||
|
||||
# @@protoc_insertion_point(module_scope)
|
866
script/api_protobuf/api_protobuf.py
Normal file
866
script/api_protobuf/api_protobuf.py
Normal file
@@ -0,0 +1,866 @@
|
||||
"""Python 3 script to automatically generate C++ classes for ESPHome's native API.
|
||||
|
||||
It's pretty crappy spaghetti code, but it works.
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from subprocess import call
|
||||
|
||||
# Generate with
|
||||
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
|
||||
import api_options_pb2 as pb
|
||||
import google.protobuf.descriptor_pb2 as descriptor
|
||||
|
||||
cwd = Path(__file__).parent
|
||||
root = cwd.parent.parent / 'esphome' / 'components' / 'api'
|
||||
prot = cwd / 'api.protoc'
|
||||
call(['protoc', '-o', prot, '-I', root, 'api.proto'])
|
||||
content = prot.read_bytes()
|
||||
|
||||
d = descriptor.FileDescriptorSet.FromString(content)
|
||||
|
||||
|
||||
def indent_list(text, padding=u' '):
|
||||
return [padding + line for line in text.splitlines()]
|
||||
|
||||
|
||||
def indent(text, padding=u' '):
|
||||
return u'\n'.join(indent_list(text, padding))
|
||||
|
||||
|
||||
def camel_to_snake(name):
|
||||
# https://stackoverflow.com/a/1176023
|
||||
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
|
||||
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
|
||||
|
||||
|
||||
class TypeInfo():
|
||||
def __init__(self, field):
|
||||
self._field = field
|
||||
|
||||
@property
|
||||
def default_value(self):
|
||||
return ''
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._field.name
|
||||
|
||||
@property
|
||||
def arg_name(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def field_name(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def number(self):
|
||||
return self._field.number
|
||||
|
||||
@property
|
||||
def repeated(self):
|
||||
return self._field.label == 3
|
||||
|
||||
@property
|
||||
def cpp_type(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def reference_type(self):
|
||||
return f'{self.cpp_type} '
|
||||
|
||||
@property
|
||||
def const_reference_type(self):
|
||||
return f'{self.cpp_type} '
|
||||
|
||||
@property
|
||||
def public_content(self) -> str:
|
||||
return [self.class_member]
|
||||
|
||||
@property
|
||||
def protected_content(self) -> str:
|
||||
return []
|
||||
|
||||
@property
|
||||
def class_member(self) -> str:
|
||||
return f'{self.cpp_type} {self.field_name}{{{self.default_value}}}; // NOLINT'
|
||||
|
||||
@property
|
||||
def decode_varint_content(self) -> str:
|
||||
content = self.decode_varint
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(f'''\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name} = {content};
|
||||
return true;
|
||||
}}''')
|
||||
|
||||
decode_varint = None
|
||||
|
||||
@property
|
||||
def decode_length_content(self) -> str:
|
||||
content = self.decode_length
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(f'''\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name} = {content};
|
||||
return true;
|
||||
}}''')
|
||||
|
||||
decode_length = None
|
||||
|
||||
@property
|
||||
def decode_32bit_content(self) -> str:
|
||||
content = self.decode_32bit
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(f'''\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name} = {content};
|
||||
return true;
|
||||
}}''')
|
||||
|
||||
decode_32bit = None
|
||||
|
||||
@property
|
||||
def decode_64bit_content(self) -> str:
|
||||
content = self.decode_64bit
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(f'''\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name} = {content};
|
||||
return true;
|
||||
}}''')
|
||||
|
||||
decode_64bit = None
|
||||
|
||||
@property
|
||||
def encode_content(self):
|
||||
return f'buffer.{self.encode_func}({self.number}, this->{self.field_name});'
|
||||
|
||||
encode_func = None
|
||||
|
||||
@property
|
||||
def dump_content(self):
|
||||
o = f'out.append(" {self.name}: ");\n'
|
||||
o += self.dump(f'this->{self.field_name}') + '\n'
|
||||
o += f'out.append("\\n");\n'
|
||||
return o
|
||||
|
||||
dump = None
|
||||
|
||||
|
||||
TYPE_INFO = {}
|
||||
|
||||
|
||||
def register_type(name):
|
||||
def func(value):
|
||||
TYPE_INFO[name] = value
|
||||
return value
|
||||
|
||||
return func
|
||||
|
||||
|
||||
@register_type(1)
|
||||
class DoubleType(TypeInfo):
|
||||
cpp_type = 'double'
|
||||
default_value = '0.0'
|
||||
decode_64bit = 'value.as_double()'
|
||||
encode_func = 'encode_double'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%g", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(2)
|
||||
class FloatType(TypeInfo):
|
||||
cpp_type = 'float'
|
||||
default_value = '0.0f'
|
||||
decode_32bit = 'value.as_float()'
|
||||
encode_func = 'encode_float'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%g", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(3)
|
||||
class Int64Type(TypeInfo):
|
||||
cpp_type = 'int64_t'
|
||||
default_value = '0'
|
||||
decode_varint = 'value.as_int64()'
|
||||
encode_func = 'encode_int64'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%ll", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(4)
|
||||
class UInt64Type(TypeInfo):
|
||||
cpp_type = 'uint64_t'
|
||||
default_value = '0'
|
||||
decode_varint = 'value.as_uint64()'
|
||||
encode_func = 'encode_uint64'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%ull", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(5)
|
||||
class Int32Type(TypeInfo):
|
||||
cpp_type = 'int32_t'
|
||||
default_value = '0'
|
||||
decode_varint = 'value.as_int32()'
|
||||
encode_func = 'encode_int32'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%d", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(6)
|
||||
class Fixed64Type(TypeInfo):
|
||||
cpp_type = 'uint64_t'
|
||||
default_value = '0'
|
||||
decode_64bit = 'value.as_fixed64()'
|
||||
encode_func = 'encode_fixed64'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%ull", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(7)
|
||||
class Fixed32Type(TypeInfo):
|
||||
cpp_type = 'uint32_t'
|
||||
default_value = '0'
|
||||
decode_32bit = 'value.as_fixed32()'
|
||||
encode_func = 'encode_fixed32'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%u", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(8)
|
||||
class BoolType(TypeInfo):
|
||||
cpp_type = 'bool'
|
||||
default_value = 'false'
|
||||
decode_varint = 'value.as_bool()'
|
||||
encode_func = 'encode_bool'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'out.append(YESNO({name}));'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(9)
|
||||
class StringType(TypeInfo):
|
||||
cpp_type = 'std::string'
|
||||
default_value = ''
|
||||
reference_type = 'std::string &'
|
||||
const_reference_type = 'const std::string &'
|
||||
decode_length = 'value.as_string()'
|
||||
encode_func = 'encode_string'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'out.append("\'").append({name}).append("\'");'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(11)
|
||||
class MessageType(TypeInfo):
|
||||
@property
|
||||
def cpp_type(self):
|
||||
return self._field.type_name[1:]
|
||||
|
||||
default_value = ''
|
||||
|
||||
@property
|
||||
def reference_type(self):
|
||||
return f'{self.cpp_type} &'
|
||||
|
||||
@property
|
||||
def const_reference_type(self):
|
||||
return f'const {self.cpp_type} &'
|
||||
|
||||
@property
|
||||
def encode_func(self):
|
||||
return f'encode_message<{self.cpp_type}>'
|
||||
|
||||
@property
|
||||
def decode_length(self):
|
||||
return f'value.as_message<{self.cpp_type}>()'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'{name}.dump_to(out);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(12)
|
||||
class BytesType(TypeInfo):
|
||||
cpp_type = 'std::string'
|
||||
default_value = ''
|
||||
reference_type = 'std::string &'
|
||||
const_reference_type = 'const std::string &'
|
||||
decode_length = 'value.as_string()'
|
||||
encode_func = 'encode_string'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'out.append("\'").append({name}).append("\'");'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(13)
|
||||
class UInt32Type(TypeInfo):
|
||||
cpp_type = 'uint32_t'
|
||||
default_value = '0'
|
||||
decode_varint = 'value.as_uint32()'
|
||||
encode_func = 'encode_uint32'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%u", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(14)
|
||||
class EnumType(TypeInfo):
|
||||
@property
|
||||
def cpp_type(self):
|
||||
return self._field.type_name[1:]
|
||||
|
||||
@property
|
||||
def decode_varint(self):
|
||||
return f'value.as_enum<{self.cpp_type}>()'
|
||||
|
||||
default_value = ''
|
||||
|
||||
@property
|
||||
def encode_func(self):
|
||||
return f'encode_enum<{self.cpp_type}>'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'out.append(proto_enum_to_string<{self.cpp_type}>({name}));'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(15)
|
||||
class SFixed32Type(TypeInfo):
|
||||
cpp_type = 'int32_t'
|
||||
default_value = '0'
|
||||
decode_32bit = 'value.as_sfixed32()'
|
||||
encode_func = 'encode_sfixed32'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%d", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(16)
|
||||
class SFixed64Type(TypeInfo):
|
||||
cpp_type = 'int64_t'
|
||||
default_value = '0'
|
||||
decode_64bit = 'value.as_sfixed64()'
|
||||
encode_func = 'encode_sfixed64'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%ll", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(17)
|
||||
class SInt32Type(TypeInfo):
|
||||
cpp_type = 'int32_t'
|
||||
default_value = '0'
|
||||
decode_varint = 'value.as_sint32()'
|
||||
encode_func = 'encode_sint32'
|
||||
|
||||
def dump(self, name):
|
||||
o = f'sprintf(buffer, "%d", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
@register_type(18)
|
||||
class SInt64Type(TypeInfo):
|
||||
cpp_type = 'int64_t'
|
||||
default_value = '0'
|
||||
decode_varint = 'value.as_sint64()'
|
||||
encode_func = 'encode_sin64'
|
||||
|
||||
def dump(self):
|
||||
o = f'sprintf(buffer, "%ll", {name});\n'
|
||||
o += f'out.append(buffer);'
|
||||
return o
|
||||
|
||||
|
||||
class RepeatedTypeInfo(TypeInfo):
|
||||
def __init__(self, field):
|
||||
super(RepeatedTypeInfo, self).__init__(field)
|
||||
self._ti = TYPE_INFO[field.type](field)
|
||||
|
||||
@property
|
||||
def cpp_type(self):
|
||||
return f'std::vector<{self._ti.cpp_type}>'
|
||||
|
||||
@property
|
||||
def reference_type(self):
|
||||
return f'{self.cpp_type} &'
|
||||
|
||||
@property
|
||||
def const_reference_type(self):
|
||||
return f'const {self.cpp_type} &'
|
||||
|
||||
@property
|
||||
def decode_varint_content(self) -> str:
|
||||
content = self._ti.decode_varint
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(f'''\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name}.push_back({content});
|
||||
return true;
|
||||
}}''')
|
||||
|
||||
@property
|
||||
def decode_length_content(self) -> str:
|
||||
content = self._ti.decode_length
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(f'''\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name}.push_back({content});
|
||||
return true;
|
||||
}}''')
|
||||
|
||||
@property
|
||||
def decode_32bit_content(self) -> str:
|
||||
content = self._ti.decode_32bit
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(f'''\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name}.push_back({content});
|
||||
return true;
|
||||
}}''')
|
||||
|
||||
@property
|
||||
def decode_64bit_content(self) -> str:
|
||||
content = self._ti.decode_64bit
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(f'''\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name}.push_back({content});
|
||||
return true;
|
||||
}}''')
|
||||
|
||||
@property
|
||||
def _ti_is_bool(self):
|
||||
# std::vector is specialized for bool, reference does not work
|
||||
return isinstance(self._ti, BoolType)
|
||||
|
||||
@property
|
||||
def encode_content(self):
|
||||
return f"""\
|
||||
for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{
|
||||
buffer.{self._ti.encode_func}({self.number}, it, true);
|
||||
}}"""
|
||||
|
||||
@property
|
||||
def dump_content(self):
|
||||
o = f'for (const auto {"" if self._ti_is_bool else "&"}it : this->{self.field_name}) {{\n'
|
||||
o += f' out.append(" {self.name}: ");\n'
|
||||
o += indent(self._ti.dump('it')) + '\n'
|
||||
o += f' out.append("\\n");\n'
|
||||
o += f'}}\n'
|
||||
return o
|
||||
|
||||
|
||||
def build_enum_type(desc):
|
||||
out = f"enum {desc.name} : uint32_t {{\n"
|
||||
for v in desc.value:
|
||||
out += f' {v.name} = {v.number},\n'
|
||||
out += '};\n'
|
||||
|
||||
cpp = f"template<>\n"
|
||||
cpp += f"const char *proto_enum_to_string<{desc.name}>({desc.name} value) {{\n"
|
||||
cpp += f" switch (value) {{\n"
|
||||
for v in desc.value:
|
||||
cpp += f' case {v.name}: return "{v.name}";\n'
|
||||
cpp += f' default: return "UNKNOWN";\n'
|
||||
cpp += f' }}\n'
|
||||
cpp += f'}}\n'
|
||||
|
||||
return out, cpp
|
||||
|
||||
|
||||
def build_message_type(desc):
|
||||
public_content = []
|
||||
protected_content = []
|
||||
decode_varint = []
|
||||
decode_length = []
|
||||
decode_32bit = []
|
||||
decode_64bit = []
|
||||
encode = []
|
||||
dump = []
|
||||
|
||||
for field in desc.field:
|
||||
if field.label == 3:
|
||||
ti = RepeatedTypeInfo(field)
|
||||
else:
|
||||
ti = TYPE_INFO[field.type](field)
|
||||
protected_content.extend(ti.protected_content)
|
||||
public_content.extend(ti.public_content)
|
||||
encode.append(ti.encode_content)
|
||||
|
||||
if ti.decode_varint_content:
|
||||
decode_varint.append(ti.decode_varint_content)
|
||||
if ti.decode_length_content:
|
||||
decode_length.append(ti.decode_length_content)
|
||||
if ti.decode_32bit_content:
|
||||
decode_32bit.append(ti.decode_32bit_content)
|
||||
if ti.decode_64bit_content:
|
||||
decode_64bit.append(ti.decode_64bit_content)
|
||||
if ti.dump_content:
|
||||
dump.append(ti.dump_content)
|
||||
|
||||
cpp = ''
|
||||
if decode_varint:
|
||||
decode_varint.append('default:\n return false;')
|
||||
o = f'bool {desc.name}::decode_varint(uint32_t field_id, ProtoVarInt value) {{\n'
|
||||
o += ' switch (field_id) {\n'
|
||||
o += indent("\n".join(decode_varint), ' ') + '\n'
|
||||
o += ' }\n'
|
||||
o += '}\n'
|
||||
cpp += o
|
||||
prot = 'bool decode_varint(uint32_t field_id, ProtoVarInt value) override;'
|
||||
protected_content.insert(0, prot)
|
||||
if decode_length:
|
||||
decode_length.append('default:\n return false;')
|
||||
o = f'bool {desc.name}::decode_length(uint32_t field_id, ProtoLengthDelimited value) {{\n'
|
||||
o += ' switch (field_id) {\n'
|
||||
o += indent("\n".join(decode_length), ' ') + '\n'
|
||||
o += ' }\n'
|
||||
o += '}\n'
|
||||
cpp += o
|
||||
prot = 'bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;'
|
||||
protected_content.insert(0, prot)
|
||||
if decode_32bit:
|
||||
decode_32bit.append('default:\n return false;')
|
||||
o = f'bool {desc.name}::decode_32bit(uint32_t field_id, Proto32Bit value) {{\n'
|
||||
o += ' switch (field_id) {\n'
|
||||
o += indent("\n".join(decode_32bit), ' ') + '\n'
|
||||
o += ' }\n'
|
||||
o += '}\n'
|
||||
cpp += o
|
||||
prot = 'bool decode_32bit(uint32_t field_id, Proto32Bit value) override;'
|
||||
protected_content.insert(0, prot)
|
||||
if decode_64bit:
|
||||
decode_64bit.append('default:\n return false;')
|
||||
o = f'bool {desc.name}::decode_64bit(uint32_t field_id, Proto64bit value) {{\n'
|
||||
o += ' switch (field_id) {\n'
|
||||
o += indent("\n".join(decode_64bit), ' ') + '\n'
|
||||
o += ' }\n'
|
||||
o += '}\n'
|
||||
cpp += o
|
||||
prot = 'bool decode_64bit(uint32_t field_id, Proto64bit value) override;'
|
||||
protected_content.insert(0, prot)
|
||||
|
||||
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{\n"
|
||||
o += indent('\n'.join(encode)) + '\n'
|
||||
o += '}\n'
|
||||
cpp += o
|
||||
prot = 'void encode(ProtoWriteBuffer buffer) const override;'
|
||||
public_content.append(prot)
|
||||
|
||||
o = f"void {desc.name}::dump_to(std::string &out) const {{\n"
|
||||
if dump:
|
||||
o += f" char buffer[64];\n"
|
||||
o += f' out.append("{desc.name} {{\\n");\n'
|
||||
o += indent('\n'.join(dump)) + '\n'
|
||||
o += f' out.append("}}");\n'
|
||||
else:
|
||||
o += f' out.append("{desc.name} {{}}");\n'
|
||||
o += '}\n'
|
||||
cpp += o
|
||||
prot = 'void dump_to(std::string &out) const override;'
|
||||
public_content.append(prot)
|
||||
|
||||
out = f"class {desc.name} : public ProtoMessage {{\n"
|
||||
out += ' public:\n'
|
||||
out += indent('\n'.join(public_content)) + '\n'
|
||||
out += ' protected:\n'
|
||||
out += indent('\n'.join(protected_content)) + '\n'
|
||||
out += "};\n"
|
||||
return out, cpp
|
||||
|
||||
|
||||
file = d.file[0]
|
||||
content = '''\
|
||||
#pragma once
|
||||
|
||||
#include "proto.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
'''
|
||||
|
||||
cpp = '''\
|
||||
#include "api_pb2.h"
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
'''
|
||||
|
||||
for enum in file.enum_type:
|
||||
s, c = build_enum_type(enum)
|
||||
content += s
|
||||
cpp += c
|
||||
|
||||
mt = file.message_type
|
||||
|
||||
for m in mt:
|
||||
s, c = build_message_type(m)
|
||||
content += s
|
||||
cpp += c
|
||||
|
||||
content += '''\
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
'''
|
||||
cpp += '''\
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
'''
|
||||
|
||||
with open(root / 'api_pb2.h', 'w') as f:
|
||||
f.write(content)
|
||||
|
||||
with open(root / 'api_pb2.cpp', 'w') as f:
|
||||
f.write(cpp)
|
||||
|
||||
SOURCE_BOTH = 0
|
||||
SOURCE_SERVER = 1
|
||||
SOURCE_CLIENT = 2
|
||||
|
||||
RECEIVE_CASES = {}
|
||||
|
||||
class_name = 'APIServerConnectionBase'
|
||||
|
||||
ifdefs = {}
|
||||
|
||||
|
||||
def get_opt(desc, opt, default=None):
|
||||
if not desc.options.HasExtension(opt):
|
||||
return default
|
||||
return desc.options.Extensions[opt]
|
||||
|
||||
|
||||
def build_service_message_type(mt):
|
||||
snake = camel_to_snake(mt.name)
|
||||
id_ = get_opt(mt, pb.id)
|
||||
if id_ is None:
|
||||
return None
|
||||
|
||||
source = get_opt(mt, pb.source, 0)
|
||||
|
||||
ifdef = get_opt(mt, pb.ifdef)
|
||||
log = get_opt(mt, pb.log, True)
|
||||
nodelay = get_opt(mt, pb.no_delay, False)
|
||||
hout = ''
|
||||
cout = ''
|
||||
|
||||
if ifdef is not None:
|
||||
ifdefs[str(mt.name)] = ifdef
|
||||
hout += f'#ifdef {ifdef}\n'
|
||||
cout += f'#ifdef {ifdef}\n'
|
||||
|
||||
if source in (SOURCE_BOTH, SOURCE_SERVER):
|
||||
# Generate send
|
||||
func = f'send_{snake}'
|
||||
hout += f'bool {func}(const {mt.name} &msg);\n'
|
||||
cout += f'bool {class_name}::{func}(const {mt.name} &msg) {{\n'
|
||||
if log:
|
||||
cout += f' ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n'
|
||||
cout += f' this->set_nodelay({str(nodelay).lower()});\n'
|
||||
cout += f' return this->send_message_<{mt.name}>(msg, {id_});\n'
|
||||
cout += f'}}\n'
|
||||
if source in (SOURCE_BOTH, SOURCE_CLIENT):
|
||||
# Generate receive
|
||||
func = f'on_{snake}'
|
||||
hout += f'virtual void {func}(const {mt.name} &value){{}};\n'
|
||||
case = ''
|
||||
if ifdef is not None:
|
||||
case += f'#ifdef {ifdef}\n'
|
||||
case += f'{mt.name} msg;\n'
|
||||
case += f'msg.decode(msg_data, msg_size);\n'
|
||||
if log:
|
||||
case += f'ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n'
|
||||
case += f'this->{func}(msg);\n'
|
||||
if ifdef is not None:
|
||||
case += f'#endif\n'
|
||||
case += 'break;'
|
||||
RECEIVE_CASES[id_] = case
|
||||
|
||||
if ifdef is not None:
|
||||
hout += f'#endif\n'
|
||||
cout += f'#endif\n'
|
||||
|
||||
return hout, cout
|
||||
|
||||
|
||||
hpp = '''\
|
||||
#pragma once
|
||||
|
||||
#include "api_pb2.h"
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
'''
|
||||
|
||||
cpp = '''\
|
||||
#include "api_pb2_service.h"
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
static const char *TAG = "api.service";
|
||||
|
||||
'''
|
||||
|
||||
hpp += f'class {class_name} : public ProtoService {{\n'
|
||||
hpp += ' public:\n'
|
||||
|
||||
for mt in file.message_type:
|
||||
obj = build_service_message_type(mt)
|
||||
if obj is None:
|
||||
continue
|
||||
hout, cout = obj
|
||||
hpp += indent(hout) + '\n'
|
||||
cpp += cout
|
||||
|
||||
cases = list(RECEIVE_CASES.items())
|
||||
cases.sort()
|
||||
hpp += ' protected:\n'
|
||||
hpp += f' bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n'
|
||||
out = f'bool {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n'
|
||||
out += f' switch(msg_type) {{\n'
|
||||
for i, case in cases:
|
||||
c = f'case {i}: {{\n'
|
||||
c += indent(case) + '\n'
|
||||
c += f'}}'
|
||||
out += indent(c, ' ') + '\n'
|
||||
out += ' default: \n'
|
||||
out += ' return false;\n'
|
||||
out += ' }\n'
|
||||
out += ' return true;\n'
|
||||
out += '}\n'
|
||||
cpp += out
|
||||
hpp += '};\n'
|
||||
|
||||
serv = file.service[0]
|
||||
class_name = 'APIServerConnection'
|
||||
hpp += '\n'
|
||||
hpp += f'class {class_name} : public {class_name}Base {{\n'
|
||||
hpp += ' public:\n'
|
||||
hpp_protected = ''
|
||||
cpp += '\n'
|
||||
|
||||
m = serv.method[0]
|
||||
for m in serv.method:
|
||||
func = m.name
|
||||
inp = m.input_type[1:]
|
||||
ret = m.output_type[1:]
|
||||
is_void = ret == 'void'
|
||||
snake = camel_to_snake(inp)
|
||||
on_func = f'on_{snake}'
|
||||
needs_conn = get_opt(m, pb.needs_setup_connection, True)
|
||||
needs_auth = get_opt(m, pb.needs_authentication, True)
|
||||
|
||||
ifdef = ifdefs.get(inp, None)
|
||||
|
||||
if ifdef is not None:
|
||||
hpp += f'#ifdef {ifdef}\n'
|
||||
hpp_protected += f'#ifdef {ifdef}\n'
|
||||
cpp += f'#ifdef {ifdef}\n'
|
||||
|
||||
hpp_protected += f' void {on_func}(const {inp} &msg) override;\n'
|
||||
hpp += f' virtual {ret} {func}(const {inp} &msg) = 0;\n'
|
||||
cpp += f'void {class_name}::{on_func}(const {inp} &msg) {{\n'
|
||||
body = ''
|
||||
if needs_conn:
|
||||
body += 'if (!this->is_connection_setup()) {\n'
|
||||
body += ' this->on_no_setup_connection();\n'
|
||||
body += ' return;\n'
|
||||
body += '}\n'
|
||||
if needs_auth:
|
||||
body += 'if (!this->is_authenticated()) {\n'
|
||||
body += ' this->on_unauthenticated_access();\n'
|
||||
body += ' return;\n'
|
||||
body += '}\n'
|
||||
|
||||
if is_void:
|
||||
body += f'this->{func}(msg);\n'
|
||||
else:
|
||||
body += f'{ret} ret = this->{func}(msg);\n'
|
||||
ret_snake = camel_to_snake(ret)
|
||||
body += f'if (!this->send_{ret_snake}(ret)) {{\n'
|
||||
body += f' this->on_fatal_error();\n'
|
||||
body += '}\n'
|
||||
cpp += indent(body) + '\n' + '}\n'
|
||||
|
||||
if ifdef is not None:
|
||||
hpp += f'#endif\n'
|
||||
hpp_protected += f'#endif\n'
|
||||
cpp += f'#endif\n'
|
||||
|
||||
hpp += ' protected:\n'
|
||||
hpp += hpp_protected
|
||||
hpp += '};\n'
|
||||
|
||||
hpp += '''\
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
'''
|
||||
cpp += '''\
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
'''
|
||||
|
||||
with open(root / 'api_pb2_service.h', 'w') as f:
|
||||
f.write(hpp)
|
||||
|
||||
with open(root / 'api_pb2_service.cpp', 'w') as f:
|
||||
f.write(cpp)
|
||||
|
||||
prot.unlink()
|
@@ -28,6 +28,7 @@ EXECUTABLE_BIT = {
|
||||
s[3].strip(): int(s[0]) for s in lines
|
||||
}
|
||||
files = [s[3].strip() for s in lines]
|
||||
files = list(filter(os.path.exists, files))
|
||||
files.sort()
|
||||
|
||||
file_types = ('.h', '.c', '.cpp', '.tcc', '.yaml', '.yml', '.ini', '.txt', '.ico', '.svg',
|
||||
|
Reference in New Issue
Block a user