mirror of
https://github.com/esphome/esphome.git
synced 2025-10-21 03:03:50 +01:00
[api] Use FixedVector for ExecuteServiceRequest/Argument arrays to eliminate reallocations (#11270)
This commit is contained in:
@@ -11,6 +11,7 @@ from typing import Any
|
||||
|
||||
import aioesphomeapi.api_options_pb2 as pb
|
||||
import google.protobuf.descriptor_pb2 as descriptor
|
||||
from google.protobuf.descriptor_pb2 import FieldDescriptorProto
|
||||
|
||||
|
||||
class WireType(IntEnum):
|
||||
@@ -148,7 +149,7 @@ class TypeInfo(ABC):
|
||||
@property
|
||||
def repeated(self) -> bool:
|
||||
"""Check if the field is repeated."""
|
||||
return self._field.label == 3
|
||||
return self._field.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
|
||||
@property
|
||||
def wire_type(self) -> WireType:
|
||||
@@ -337,7 +338,7 @@ def create_field_type_info(
|
||||
needs_encode: bool = True,
|
||||
) -> TypeInfo:
|
||||
"""Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options."""
|
||||
if field.label == 3: # repeated
|
||||
if field.label == FieldDescriptorProto.LABEL_REPEATED:
|
||||
# Check if this repeated field has fixed_array_with_length_define option
|
||||
if (
|
||||
fixed_size := get_field_opt(field, pb.fixed_array_with_length_define)
|
||||
@@ -1879,6 +1880,9 @@ def build_message_type(
|
||||
)
|
||||
public_content.append("#endif")
|
||||
|
||||
# Collect fixed_vector fields for custom decode generation
|
||||
fixed_vector_fields = []
|
||||
|
||||
for field in desc.field:
|
||||
# Skip deprecated fields completely
|
||||
if field.options.deprecated:
|
||||
@@ -1887,7 +1891,7 @@ def build_message_type(
|
||||
# Validate that fixed_array_size is only used in encode-only messages
|
||||
if (
|
||||
needs_decode
|
||||
and field.label == 3
|
||||
and field.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
and get_field_opt(field, pb.fixed_array_size) is not None
|
||||
):
|
||||
raise ValueError(
|
||||
@@ -1900,7 +1904,7 @@ def build_message_type(
|
||||
# Validate that fixed_array_with_length_define is only used in encode-only messages
|
||||
if (
|
||||
needs_decode
|
||||
and field.label == 3
|
||||
and field.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
and get_field_opt(field, pb.fixed_array_with_length_define) is not None
|
||||
):
|
||||
raise ValueError(
|
||||
@@ -1910,6 +1914,14 @@ def build_message_type(
|
||||
f"since we cannot trust or control the number of items received from clients."
|
||||
)
|
||||
|
||||
# Collect fixed_vector repeated fields for custom decode generation
|
||||
if (
|
||||
needs_decode
|
||||
and field.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
and get_field_opt(field, pb.fixed_vector, False)
|
||||
):
|
||||
fixed_vector_fields.append((field.name, field.number))
|
||||
|
||||
ti = create_field_type_info(field, needs_decode, needs_encode)
|
||||
|
||||
# Skip field declarations for fields that are in the base class
|
||||
@@ -2018,6 +2030,22 @@ def build_message_type(
|
||||
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
|
||||
protected_content.insert(0, prot)
|
||||
|
||||
# Generate custom decode() override for messages with FixedVector fields
|
||||
if fixed_vector_fields:
|
||||
# Generate the decode() implementation in cpp
|
||||
o = f"void {desc.name}::decode(const uint8_t *buffer, size_t length) {{\n"
|
||||
# Count and init each FixedVector field
|
||||
for field_name, field_number in fixed_vector_fields:
|
||||
o += f" uint32_t count_{field_name} = ProtoDecodableMessage::count_repeated_field(buffer, length, {field_number});\n"
|
||||
o += f" this->{field_name}.init(count_{field_name});\n"
|
||||
# Call parent decode to populate the fields
|
||||
o += " ProtoDecodableMessage::decode(buffer, length);\n"
|
||||
o += "}\n"
|
||||
cpp += o
|
||||
# Generate the decode() declaration in header (public method)
|
||||
prot = "void decode(const uint8_t *buffer, size_t length) override;"
|
||||
public_content.append(prot)
|
||||
|
||||
# Only generate encode method if this message needs encoding and has fields
|
||||
if needs_encode and encode:
|
||||
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"
|
||||
|
Reference in New Issue
Block a user