1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-01 10:52:19 +01:00

[api] Implement zero-copy for all protobuf bytes fields (#9761)

This commit is contained in:
J. Nick Koston
2025-07-21 11:38:39 -10:00
committed by GitHub
parent 74ce3d2c0b
commit db62a94712
9 changed files with 154 additions and 108 deletions

View File

@@ -113,8 +113,15 @@ def force_str(force: bool) -> str:
class TypeInfo(ABC):
"""Base class for all type information."""
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
def __init__(
self,
field: descriptor.FieldDescriptorProto,
needs_decode: bool = True,
needs_encode: bool = True,
) -> None:
self._field = field
self._needs_decode = needs_decode
self._needs_encode = needs_encode
@property
def default_value(self) -> str:
@@ -313,7 +320,11 @@ def validate_field_type(field_type: int, field_name: str = "") -> None:
)
def create_field_type_info(field: descriptor.FieldDescriptorProto) -> TypeInfo:
def create_field_type_info(
field: descriptor.FieldDescriptorProto,
needs_decode: bool = True,
needs_encode: bool = True,
) -> TypeInfo:
"""Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options."""
if field.label == 3: # repeated
return RepeatedTypeInfo(field)
@@ -325,6 +336,10 @@ def create_field_type_info(field: descriptor.FieldDescriptorProto) -> TypeInfo:
):
return FixedArrayBytesType(field, fixed_size)
# Special handling for bytes fields
if field.type == 12:
return BytesType(field, needs_decode, needs_encode)
validate_field_type(field.type, field.name)
return TYPE_INFO[field.type](field)
@@ -589,20 +604,59 @@ class BytesType(TypeInfo):
default_value = ""
reference_type = "std::string &"
const_reference_type = "const std::string &"
decode_length = "value.as_string()"
encode_func = "encode_bytes"
decode_length = "value.as_string()"
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
@property
def public_content(self) -> list[str]:
content: list[str] = []
# Add std::string storage if message needs decoding
if self._needs_decode:
content.append(f"std::string {self.field_name}{{}};")
if self._needs_encode:
content.extend(
[
# Add pointer/length fields if message needs encoding
f"const uint8_t* {self.field_name}_ptr_{{nullptr}};",
f"size_t {self.field_name}_len_{{0}};",
# Add setter method if message needs encoding
f"void set_{self.field_name}(const uint8_t* data, size_t len) {{",
f" this->{self.field_name}_ptr_ = data;",
f" this->{self.field_name}_len_ = len;",
"}",
]
)
return content
@property
def encode_content(self) -> str:
return f"buffer.encode_bytes({self.number}, reinterpret_cast<const uint8_t*>(this->{self.field_name}.data()), this->{self.field_name}.size());"
return f"buffer.encode_bytes({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_);"
def dump(self, name: str) -> str:
o = f"out.append(format_hex_pretty({name}));"
return o
ptr_dump = f"format_hex_pretty(this->{self.field_name}_ptr_, this->{self.field_name}_len_)"
str_dump = f"format_hex_pretty(reinterpret_cast<const uint8_t*>(this->{self.field_name}.data()), this->{self.field_name}.size())"
# For SOURCE_CLIENT only, always use std::string
if not self._needs_encode:
return f"out.append({str_dump});"
# For SOURCE_SERVER, always use pointer/length
if not self._needs_decode:
return f"out.append({ptr_dump});"
# For SOURCE_BOTH, check if pointer is set (sending) or use string (received)
return (
f"if (this->{self.field_name}_ptr_ != nullptr) {{\n"
f" out.append({ptr_dump});\n"
f" }} else {{\n"
f" out.append({str_dump});\n"
f" }}"
)
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_string_field")
return f"ProtoSize::add_bytes_field(total_size, {self.calculate_field_id_size()}, this->{self.field_name}_len_);"
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical bytes
@@ -1257,7 +1311,7 @@ def build_message_type(
if field.options.deprecated:
continue
ti = create_field_type_info(field)
ti = create_field_type_info(field, needs_decode, needs_encode)
# Skip field declarations for fields that are in the base class
# but include their encode/decode logic
@@ -1572,10 +1626,20 @@ def build_base_class(
public_content = []
protected_content = []
# Determine if any message using this base class needs decoding/encoding
needs_decode = any(
message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT)
for msg in messages
)
needs_encode = any(
message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_SERVER)
for msg in messages
)
# For base classes, we only declare the fields but don't handle encode/decode
# The derived classes will handle encoding/decoding with their specific field numbers
for field in common_fields:
ti = create_field_type_info(field)
ti = create_field_type_info(field, needs_decode, needs_encode)
# Get field_ifdef if it's consistent across all messages
field_ifdef = get_common_field_ifdef(field.name, messages)
@@ -1586,12 +1650,6 @@ def build_base_class(
if ti.public_content:
public_content.extend(wrap_with_ifdef(ti.public_content, field_ifdef))
# Determine if any message using this base class needs decoding
needs_decode = any(
message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT)
for msg in messages
)
# Build header
parent_class = "ProtoDecodableMessage" if needs_decode else "ProtoMessage"
out = f"class {base_class_name} : public {parent_class} {{\n"