mirror of
https://github.com/esphome/esphome.git
synced 2025-09-01 10:52:19 +01:00
[api] Eliminate heap allocations when populating repeated fields from containers (#9948)
This commit is contained in:
@@ -1170,6 +1170,10 @@ class FixedArrayRepeatedType(TypeInfo):
|
||||
class RepeatedTypeInfo(TypeInfo):
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
||||
super().__init__(field)
|
||||
# Check if this is a pointer field by looking for container_pointer option
|
||||
self._container_type = get_field_opt(field, pb.container_pointer, "")
|
||||
self._use_pointer = bool(self._container_type)
|
||||
|
||||
# For repeated fields, we need to get the base type info
|
||||
# but we can't call create_field_type_info as it would cause recursion
|
||||
# So we extract just the type creation logic
|
||||
@@ -1185,6 +1189,14 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
|
||||
@property
|
||||
def cpp_type(self) -> str:
|
||||
if self._use_pointer and self._container_type:
|
||||
# For pointer fields, use the specified container type
|
||||
# If the container type already includes the element type (e.g., std::set<climate::ClimateMode>)
|
||||
# use it as-is, otherwise append the element type
|
||||
if "<" in self._container_type and ">" in self._container_type:
|
||||
return f"const {self._container_type}*"
|
||||
else:
|
||||
return f"const {self._container_type}<{self._ti.cpp_type}>*"
|
||||
return f"std::vector<{self._ti.cpp_type}>"
|
||||
|
||||
@property
|
||||
@@ -1205,6 +1217,9 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
|
||||
@property
|
||||
def decode_varint_content(self) -> str:
|
||||
# Pointer fields don't support decoding
|
||||
if self._use_pointer:
|
||||
return None
|
||||
content = self._ti.decode_varint
|
||||
if content is None:
|
||||
return None
|
||||
@@ -1214,6 +1229,9 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
|
||||
@property
|
||||
def decode_length_content(self) -> str:
|
||||
# Pointer fields don't support decoding
|
||||
if self._use_pointer:
|
||||
return None
|
||||
content = self._ti.decode_length
|
||||
if content is None and isinstance(self._ti, MessageType):
|
||||
# Special handling for non-template message decoding
|
||||
@@ -1226,6 +1244,9 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
|
||||
@property
|
||||
def decode_32bit_content(self) -> str:
|
||||
# Pointer fields don't support decoding
|
||||
if self._use_pointer:
|
||||
return None
|
||||
content = self._ti.decode_32bit
|
||||
if content is None:
|
||||
return None
|
||||
@@ -1235,6 +1256,9 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
|
||||
@property
|
||||
def decode_64bit_content(self) -> str:
|
||||
# Pointer fields don't support decoding
|
||||
if self._use_pointer:
|
||||
return None
|
||||
content = self._ti.decode_64bit
|
||||
if content is None:
|
||||
return None
|
||||
@@ -1249,16 +1273,31 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
|
||||
@property
|
||||
def encode_content(self) -> str:
|
||||
o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
|
||||
if isinstance(self._ti, EnumType):
|
||||
o += f" buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>(it), true);\n"
|
||||
if self._use_pointer:
|
||||
# For pointer fields, just dereference (pointer should never be null in our use case)
|
||||
o = f"for (const auto &it : *this->{self.field_name}) {{\n"
|
||||
if isinstance(self._ti, EnumType):
|
||||
o += f" buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>(it), true);\n"
|
||||
else:
|
||||
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
|
||||
o += "}"
|
||||
return o
|
||||
else:
|
||||
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
|
||||
o += "}"
|
||||
return o
|
||||
o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
|
||||
if isinstance(self._ti, EnumType):
|
||||
o += f" buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>(it), true);\n"
|
||||
else:
|
||||
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
|
||||
o += "}"
|
||||
return o
|
||||
|
||||
@property
|
||||
def dump_content(self) -> str:
|
||||
if self._use_pointer:
|
||||
# For pointer fields, dereference and use the existing helper
|
||||
return _generate_array_dump_content(
|
||||
self._ti, f"*this->{self.field_name}", self.name, is_bool=False
|
||||
)
|
||||
return _generate_array_dump_content(
|
||||
self._ti, f"this->{self.field_name}", self.name, is_bool=self._ti_is_bool
|
||||
)
|
||||
@@ -1269,30 +1308,34 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
# For repeated fields, we always need to pass force=True to the underlying type's calculation
|
||||
# This is because the encode method always sets force=true for repeated fields
|
||||
|
||||
# Handle message types separately as they use a dedicated helper
|
||||
if isinstance(self._ti, MessageType):
|
||||
# For repeated messages, use the dedicated helper that handles iteration internally
|
||||
field_id_size = self._ti.calculate_field_id_size()
|
||||
o = f"size.add_repeated_message({field_id_size}, {name});"
|
||||
return o
|
||||
container = f"*{name}" if self._use_pointer else name
|
||||
return f"size.add_repeated_message({field_id_size}, {container});"
|
||||
|
||||
# For other repeated types, use the underlying type's size calculation with force=True
|
||||
o = f"if (!{name}.empty()) {{\n"
|
||||
# For non-message types, generate size calculation with iteration
|
||||
container_ref = f"*{name}" if self._use_pointer else name
|
||||
empty_check = f"{name}->empty()" if self._use_pointer else f"{name}.empty()"
|
||||
|
||||
# Check if this is a fixed-size type by seeing if it has a fixed byte count
|
||||
o = f"if (!{empty_check}) {{\n"
|
||||
|
||||
# Check if this is a fixed-size type
|
||||
num_bytes = self._ti.get_fixed_size_bytes()
|
||||
if num_bytes is not None:
|
||||
# Fixed types have constant size per element, so we can multiply
|
||||
# Fixed types have constant size per element
|
||||
field_id_size = self._ti.calculate_field_id_size()
|
||||
# Pre-calculate the total bytes per element
|
||||
bytes_per_element = field_id_size + num_bytes
|
||||
o += (
|
||||
f" size.add_precalculated_size({name}.size() * {bytes_per_element});\n"
|
||||
)
|
||||
size_expr = f"{name}->size()" if self._use_pointer else f"{name}.size()"
|
||||
o += f" size.add_precalculated_size({size_expr} * {bytes_per_element});\n"
|
||||
else:
|
||||
# Other types need the actual value
|
||||
o += f" for (const auto {'' if self._ti_is_bool else '&'}it : {name}) {{\n"
|
||||
auto_ref = "" if self._ti_is_bool else "&"
|
||||
o += f" for (const auto {auto_ref}it : {container_ref}) {{\n"
|
||||
o += f" {self._ti.get_size_calculation('it', True)}\n"
|
||||
o += " }\n"
|
||||
|
||||
o += "}"
|
||||
return o
|
||||
|
||||
@@ -2080,6 +2123,7 @@ def main() -> None:
|
||||
d = descriptor.FileDescriptorSet.FromString(proto_content)
|
||||
|
||||
file = d.file[0]
|
||||
|
||||
content = FILE_HEADER
|
||||
content += """\
|
||||
#pragma once
|
||||
@@ -2088,7 +2132,10 @@ def main() -> None:
|
||||
#include "esphome/core/string_ref.h"
|
||||
|
||||
#include "proto.h"
|
||||
#include "api_pb2_includes.h"
|
||||
"""
|
||||
|
||||
content += """
|
||||
namespace esphome::api {
|
||||
|
||||
"""
|
||||
|
Reference in New Issue
Block a user