mirror of
https://github.com/esphome/esphome.git
synced 2025-09-01 10:52:19 +01:00
[api] Optimize protobuf memory usage with fixed-size arrays for Bluetooth UUIDs (#9782)
This commit is contained in:
@@ -327,6 +327,9 @@ def create_field_type_info(
|
||||
) -> TypeInfo:
|
||||
"""Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options."""
|
||||
if field.label == 3: # repeated
|
||||
# Check if this repeated field has fixed_array_size option
|
||||
if (fixed_size := get_field_opt(field, pb.fixed_array_size)) is not None:
|
||||
return FixedArrayRepeatedType(field, fixed_size)
|
||||
return RepeatedTypeInfo(field)
|
||||
|
||||
# Check for fixed_array_size option on bytes fields
|
||||
@@ -593,6 +596,8 @@ class MessageType(TypeInfo):
|
||||
return self._get_simple_size_calculation(name, force, "add_message_object")
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
# For message types, we can't easily estimate the submessage size without
|
||||
# access to the actual message definition. This is just a rough estimate.
|
||||
return (
|
||||
self.calculate_field_id_size() + 16
|
||||
) # field ID + 16 bytes estimated submessage
|
||||
@@ -883,6 +888,111 @@ class SInt64Type(TypeInfo):
|
||||
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
|
||||
|
||||
|
||||
class FixedArrayRepeatedType(TypeInfo):
|
||||
"""Special type for fixed-size repeated fields using std::array.
|
||||
|
||||
Fixed arrays are only supported for encoding (SOURCE_SERVER) since we cannot
|
||||
control how many items we receive when decoding.
|
||||
"""
|
||||
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto, size: int) -> None:
|
||||
super().__init__(field)
|
||||
self.array_size = size
|
||||
# Create the element type info
|
||||
validate_field_type(field.type, field.name)
|
||||
self._ti: TypeInfo = TYPE_INFO[field.type](field)
|
||||
|
||||
@property
|
||||
def cpp_type(self) -> str:
|
||||
return f"std::array<{self._ti.cpp_type}, {self.array_size}>"
|
||||
|
||||
@property
|
||||
def reference_type(self) -> str:
|
||||
return f"{self.cpp_type} &"
|
||||
|
||||
@property
|
||||
def const_reference_type(self) -> str:
|
||||
return f"const {self.cpp_type} &"
|
||||
|
||||
@property
|
||||
def wire_type(self) -> WireType:
|
||||
"""Get the wire type for this fixed array field."""
|
||||
return self._ti.wire_type
|
||||
|
||||
@property
|
||||
def public_content(self) -> list[str]:
|
||||
# Just the array member, no index needed since we don't decode
|
||||
return [f"{self.cpp_type} {self.field_name}{{}};"]
|
||||
|
||||
# No decode methods needed - fixed arrays don't support decoding
|
||||
# The base class TypeInfo already returns None for all decode properties
|
||||
|
||||
@property
|
||||
def encode_content(self) -> str:
|
||||
# Helper to generate encode statement for a single element
|
||||
def encode_element(element: str) -> str:
|
||||
if isinstance(self._ti, EnumType):
|
||||
return f"buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>({element}), true);"
|
||||
else:
|
||||
return f"buffer.{self._ti.encode_func}({self.number}, {element}, true);"
|
||||
|
||||
# Unroll small arrays for efficiency
|
||||
if self.array_size == 1:
|
||||
return encode_element(f"this->{self.field_name}[0]")
|
||||
elif self.array_size == 2:
|
||||
return (
|
||||
encode_element(f"this->{self.field_name}[0]")
|
||||
+ "\n "
|
||||
+ encode_element(f"this->{self.field_name}[1]")
|
||||
)
|
||||
|
||||
# Use loops for larger arrays
|
||||
o = f"for (const auto &it : this->{self.field_name}) {{\n"
|
||||
o += f" {encode_element('it')}\n"
|
||||
o += "}"
|
||||
return o
|
||||
|
||||
@property
|
||||
def dump_content(self) -> str:
|
||||
o = f"for (const auto &it : this->{self.field_name}) {{\n"
|
||||
o += f' out.append(" {self.name}: ");\n'
|
||||
o += indent(self._ti.dump("it")) + "\n"
|
||||
o += ' out.append("\\n");\n'
|
||||
o += "}\n"
|
||||
return o
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
# This is used when dumping the array itself (not its elements)
|
||||
# Since dump_content handles the iteration, this is not used directly
|
||||
return ""
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
# For fixed arrays, we always encode all elements
|
||||
|
||||
# Special case for single-element arrays - no loop needed
|
||||
if self.array_size == 1:
|
||||
return self._ti.get_size_calculation(f"{name}[0]", True)
|
||||
|
||||
# Special case for 2-element arrays - unroll the calculation
|
||||
if self.array_size == 2:
|
||||
return (
|
||||
self._ti.get_size_calculation(f"{name}[0]", True)
|
||||
+ "\n "
|
||||
+ self._ti.get_size_calculation(f"{name}[1]", True)
|
||||
)
|
||||
|
||||
# Use loops for larger arrays
|
||||
o = f"for (const auto &it : {name}) {{\n"
|
||||
o += f" {self._ti.get_size_calculation('it', True)}\n"
|
||||
o += "}"
|
||||
return o
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
# For fixed arrays, estimate underlying type size * array size
|
||||
underlying_size = self._ti.get_estimated_size()
|
||||
return underlying_size * self.array_size
|
||||
|
||||
|
||||
class RepeatedTypeInfo(TypeInfo):
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
||||
super().__init__(field)
|
||||
@@ -1311,6 +1421,19 @@ def build_message_type(
|
||||
if field.options.deprecated:
|
||||
continue
|
||||
|
||||
# Validate that fixed_array_size is only used in encode-only messages
|
||||
if (
|
||||
needs_decode
|
||||
and field.label == 3
|
||||
and get_field_opt(field, pb.fixed_array_size) is not None
|
||||
):
|
||||
raise ValueError(
|
||||
f"Message '{desc.name}' uses fixed_array_size on field '{field.name}' "
|
||||
f"but has source={SOURCE_NAMES[source]}. "
|
||||
f"Fixed arrays are only supported for SOURCE_SERVER (encode-only) messages "
|
||||
f"since we cannot trust or control the number of items received from clients."
|
||||
)
|
||||
|
||||
ti = create_field_type_info(field, needs_decode, needs_encode)
|
||||
|
||||
# Skip field declarations for fields that are in the base class
|
||||
@@ -1500,6 +1623,12 @@ SOURCE_BOTH = 0
|
||||
SOURCE_SERVER = 1
|
||||
SOURCE_CLIENT = 2
|
||||
|
||||
SOURCE_NAMES = {
|
||||
SOURCE_BOTH: "SOURCE_BOTH",
|
||||
SOURCE_SERVER: "SOURCE_SERVER",
|
||||
SOURCE_CLIENT: "SOURCE_CLIENT",
|
||||
}
|
||||
|
||||
RECEIVE_CASES: dict[int, tuple[str, str | None]] = {}
|
||||
|
||||
ifdefs: dict[str, str] = {}
|
||||
|
Reference in New Issue
Block a user