1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-11 07:42:26 +01:00

Merge remote-tracking branch 'upstream/dev' into proto_field_ifdefs

This commit is contained in:
J. Nick Koston
2025-07-15 15:19:26 -10:00
197 changed files with 5145 additions and 3199 deletions

View File

@@ -8,7 +8,6 @@ from pathlib import Path
import re
from subprocess import call
import sys
from textwrap import dedent
from typing import Any
import aioesphomeapi.api_options_pb2 as pb
@@ -181,13 +180,7 @@ class TypeInfo(ABC):
content = self.decode_varint
if content is None:
return None
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name} = {content};
return true;
}}"""
)
return f"case {self.number}: this->{self.field_name} = {content}; break;"
decode_varint = None
@@ -196,13 +189,7 @@ class TypeInfo(ABC):
content = self.decode_length
if content is None:
return None
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name} = {content};
return true;
}}"""
)
return f"case {self.number}: this->{self.field_name} = {content}; break;"
decode_length = None
@@ -211,13 +198,7 @@ class TypeInfo(ABC):
content = self.decode_32bit
if content is None:
return None
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name} = {content};
return true;
}}"""
)
return f"case {self.number}: this->{self.field_name} = {content}; break;"
decode_32bit = None
@@ -226,13 +207,7 @@ class TypeInfo(ABC):
content = self.decode_64bit
if content is None:
return None
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name} = {content};
return true;
}}"""
)
return f"case {self.number}: this->{self.field_name} = {content}; break;"
decode_64bit = None
@@ -337,6 +312,37 @@ class TypeInfo(ABC):
TYPE_INFO: dict[int, TypeInfo] = {}
# Unsupported 64-bit types that would add overhead for embedded systems
# TYPE_DOUBLE = 1, TYPE_FIXED64 = 6, TYPE_SFIXED64 = 16, TYPE_SINT64 = 18
UNSUPPORTED_TYPES = {1: "double", 6: "fixed64", 16: "sfixed64", 18: "sint64"}
def validate_field_type(field_type: int, field_name: str = "") -> None:
"""Validate that the field type is supported by ESPHome API.
Raises ValueError for unsupported 64-bit types.
"""
if field_type in UNSUPPORTED_TYPES:
type_name = UNSUPPORTED_TYPES[field_type]
field_info = f" (field: {field_name})" if field_name else ""
raise ValueError(
f"64-bit type '{type_name}'{field_info} is not supported by ESPHome API. "
"These types add significant overhead for embedded systems. "
"If you need 64-bit support, please add the necessary encoding/decoding "
"functions to proto.h/proto.cpp first."
)
def get_type_info_for_field(field: descriptor.FieldDescriptorProto) -> TypeInfo:
"""Get the appropriate TypeInfo for a field, handling repeated fields.
Also validates that the field type is supported.
"""
if field.label == 3: # repeated
return RepeatedTypeInfo(field)
validate_field_type(field.type, field.name)
return TYPE_INFO[field.type](field)
def register_type(name: int):
"""Decorator to register a type with a name and number."""
@@ -573,13 +579,7 @@ class MessageType(TypeInfo):
@property
def decode_length_content(self) -> str:
# Custom decode that doesn't use templates
return dedent(
f"""\
case {self.number}: {{
value.decode_to_message(this->{self.field_name});
return true;
}}"""
)
return f"case {self.number}: value.decode_to_message(this->{self.field_name}); break;"
def dump(self, name: str) -> str:
o = f"{name}.dump_to(out);"
@@ -762,6 +762,7 @@ class SInt64Type(TypeInfo):
class RepeatedTypeInfo(TypeInfo):
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
super().__init__(field)
validate_field_type(field.type, field.name)
self._ti: TypeInfo = TYPE_INFO[field.type](field)
@property
@@ -789,12 +790,8 @@ class RepeatedTypeInfo(TypeInfo):
content = self._ti.decode_varint
if content is None:
return None
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name}.push_back({content});
return true;
}}"""
return (
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
)
@property
@@ -802,22 +799,11 @@ class RepeatedTypeInfo(TypeInfo):
content = self._ti.decode_length
if content is None and isinstance(self._ti, MessageType):
# Special handling for non-template message decoding
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name}.emplace_back();
value.decode_to_message(this->{self.field_name}.back());
return true;
}}"""
)
return f"case {self.number}: this->{self.field_name}.emplace_back(); value.decode_to_message(this->{self.field_name}.back()); break;"
if content is None:
return None
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name}.push_back({content});
return true;
}}"""
return (
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
)
@property
@@ -825,12 +811,8 @@ class RepeatedTypeInfo(TypeInfo):
content = self._ti.decode_32bit
if content is None:
return None
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name}.push_back({content});
return true;
}}"""
return (
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
)
@property
@@ -838,12 +820,8 @@ class RepeatedTypeInfo(TypeInfo):
content = self._ti.decode_64bit
if content is None:
return None
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name}.push_back({content});
return true;
}}"""
return (
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
)
@property
@@ -1049,10 +1027,7 @@ def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int:
total_size = 0
for field in desc.field:
if field.label == 3: # repeated
ti = RepeatedTypeInfo(field)
else:
ti = TYPE_INFO[field.type](field)
ti = get_type_info_for_field(field)
# Add estimated size for this field
total_size += ti.get_estimated_size()
@@ -1186,41 +1161,45 @@ def build_message_type(
cpp = ""
if decode_varint:
decode_varint.append("default:\n return false;")
o = f"bool {desc.name}::decode_varint(uint32_t field_id, ProtoVarInt value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_varint), " ") + "\n"
o += " default: return false;\n"
o += " }\n"
o += " return true;\n"
o += "}\n"
cpp += o
prot = "bool decode_varint(uint32_t field_id, ProtoVarInt value) override;"
protected_content.insert(0, prot)
if decode_length:
decode_length.append("default:\n return false;")
o = f"bool {desc.name}::decode_length(uint32_t field_id, ProtoLengthDelimited value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_length), " ") + "\n"
o += " default: return false;\n"
o += " }\n"
o += " return true;\n"
o += "}\n"
cpp += o
prot = "bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;"
protected_content.insert(0, prot)
if decode_32bit:
decode_32bit.append("default:\n return false;")
o = f"bool {desc.name}::decode_32bit(uint32_t field_id, Proto32Bit value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_32bit), " ") + "\n"
o += " default: return false;\n"
o += " }\n"
o += " return true;\n"
o += "}\n"
cpp += o
prot = "bool decode_32bit(uint32_t field_id, Proto32Bit value) override;"
protected_content.insert(0, prot)
if decode_64bit:
decode_64bit.append("default:\n return false;")
o = f"bool {desc.name}::decode_64bit(uint32_t field_id, Proto64Bit value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_64bit), " ") + "\n"
o += " default: return false;\n"
o += " }\n"
o += " return true;\n"
o += "}\n"
cpp += o
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
@@ -1394,10 +1373,7 @@ def build_base_class(
# For base classes, we only declare the fields but don't handle encode/decode
# The derived classes will handle encoding/decoding with their specific field numbers
for field in common_fields:
if field.label == 3: # repeated
ti = RepeatedTypeInfo(field)
else:
ti = TYPE_INFO[field.type](field)
ti = get_type_info_for_field(field)
# Only add field declarations, not encode/decode logic
protected_content.extend(ti.protected_content)