mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 15:12:06 +00:00 
			
		
		
		
	[api] Implement zero-copy for all protobuf bytes fields (#9761)
This commit is contained in:
		| @@ -113,8 +113,15 @@ def force_str(force: bool) -> str: | ||||
| class TypeInfo(ABC): | ||||
|     """Base class for all type information.""" | ||||
|  | ||||
|     def __init__(self, field: descriptor.FieldDescriptorProto) -> None: | ||||
|     def __init__( | ||||
|         self, | ||||
|         field: descriptor.FieldDescriptorProto, | ||||
|         needs_decode: bool = True, | ||||
|         needs_encode: bool = True, | ||||
|     ) -> None: | ||||
|         self._field = field | ||||
|         self._needs_decode = needs_decode | ||||
|         self._needs_encode = needs_encode | ||||
|  | ||||
|     @property | ||||
|     def default_value(self) -> str: | ||||
| @@ -313,7 +320,11 @@ def validate_field_type(field_type: int, field_name: str = "") -> None: | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def create_field_type_info(field: descriptor.FieldDescriptorProto) -> TypeInfo: | ||||
| def create_field_type_info( | ||||
|     field: descriptor.FieldDescriptorProto, | ||||
|     needs_decode: bool = True, | ||||
|     needs_encode: bool = True, | ||||
| ) -> TypeInfo: | ||||
|     """Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options.""" | ||||
|     if field.label == 3:  # repeated | ||||
|         return RepeatedTypeInfo(field) | ||||
| @@ -325,6 +336,10 @@ def create_field_type_info(field: descriptor.FieldDescriptorProto) -> TypeInfo: | ||||
|     ): | ||||
|         return FixedArrayBytesType(field, fixed_size) | ||||
|  | ||||
|     # Special handling for bytes fields | ||||
|     if field.type == 12: | ||||
|         return BytesType(field, needs_decode, needs_encode) | ||||
|  | ||||
|     validate_field_type(field.type, field.name) | ||||
|     return TYPE_INFO[field.type](field) | ||||
|  | ||||
| @@ -589,20 +604,59 @@ class BytesType(TypeInfo): | ||||
|     default_value = "" | ||||
|     reference_type = "std::string &" | ||||
|     const_reference_type = "const std::string &" | ||||
|     decode_length = "value.as_string()" | ||||
|     encode_func = "encode_bytes" | ||||
|     decode_length = "value.as_string()" | ||||
|     wire_type = WireType.LENGTH_DELIMITED  # Uses wire type 2 | ||||
|  | ||||
|     @property | ||||
|     def public_content(self) -> list[str]: | ||||
|         content: list[str] = [] | ||||
|         # Add std::string storage if message needs decoding | ||||
|         if self._needs_decode: | ||||
|             content.append(f"std::string {self.field_name}{{}};") | ||||
|  | ||||
|         if self._needs_encode: | ||||
|             content.extend( | ||||
|                 [ | ||||
|                     # Add pointer/length fields if message needs encoding | ||||
|                     f"const uint8_t* {self.field_name}_ptr_{{nullptr}};", | ||||
|                     f"size_t {self.field_name}_len_{{0}};", | ||||
|                     # Add setter method if message needs encoding | ||||
|                     f"void set_{self.field_name}(const uint8_t* data, size_t len) {{", | ||||
|                     f"  this->{self.field_name}_ptr_ = data;", | ||||
|                     f"  this->{self.field_name}_len_ = len;", | ||||
|                     "}", | ||||
|                 ] | ||||
|             ) | ||||
|         return content | ||||
|  | ||||
|     @property | ||||
|     def encode_content(self) -> str: | ||||
|         return f"buffer.encode_bytes({self.number}, reinterpret_cast<const uint8_t*>(this->{self.field_name}.data()), this->{self.field_name}.size());" | ||||
|         return f"buffer.encode_bytes({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_);" | ||||
|  | ||||
|     def dump(self, name: str) -> str: | ||||
|         o = f"out.append(format_hex_pretty({name}));" | ||||
|         return o | ||||
|         ptr_dump = f"format_hex_pretty(this->{self.field_name}_ptr_, this->{self.field_name}_len_)" | ||||
|         str_dump = f"format_hex_pretty(reinterpret_cast<const uint8_t*>(this->{self.field_name}.data()), this->{self.field_name}.size())" | ||||
|  | ||||
|         # For SOURCE_CLIENT only, always use std::string | ||||
|         if not self._needs_encode: | ||||
|             return f"out.append({str_dump});" | ||||
|  | ||||
|         # For SOURCE_SERVER, always use pointer/length | ||||
|         if not self._needs_decode: | ||||
|             return f"out.append({ptr_dump});" | ||||
|  | ||||
|         # For SOURCE_BOTH, check if pointer is set (sending) or use string (received) | ||||
|         return ( | ||||
|             f"if (this->{self.field_name}_ptr_ != nullptr) {{\n" | ||||
|             f"    out.append({ptr_dump});\n" | ||||
|             f"  }} else {{\n" | ||||
|             f"    out.append({str_dump});\n" | ||||
|             f"  }}" | ||||
|         ) | ||||
|  | ||||
|     def get_size_calculation(self, name: str, force: bool = False) -> str: | ||||
|         return self._get_simple_size_calculation(name, force, "add_string_field") | ||||
|         return f"ProtoSize::add_bytes_field(total_size, {self.calculate_field_id_size()}, this->{self.field_name}_len_);" | ||||
|  | ||||
|     def get_estimated_size(self) -> int: | ||||
|         return self.calculate_field_id_size() + 8  # field ID + 8 bytes typical bytes | ||||
| @@ -1257,7 +1311,7 @@ def build_message_type( | ||||
|         if field.options.deprecated: | ||||
|             continue | ||||
|  | ||||
|         ti = create_field_type_info(field) | ||||
|         ti = create_field_type_info(field, needs_decode, needs_encode) | ||||
|  | ||||
|         # Skip field declarations for fields that are in the base class | ||||
|         # but include their encode/decode logic | ||||
| @@ -1572,10 +1626,20 @@ def build_base_class( | ||||
|     public_content = [] | ||||
|     protected_content = [] | ||||
|  | ||||
|     # Determine if any message using this base class needs decoding/encoding | ||||
|     needs_decode = any( | ||||
|         message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT) | ||||
|         for msg in messages | ||||
|     ) | ||||
|     needs_encode = any( | ||||
|         message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_SERVER) | ||||
|         for msg in messages | ||||
|     ) | ||||
|  | ||||
|     # For base classes, we only declare the fields but don't handle encode/decode | ||||
|     # The derived classes will handle encoding/decoding with their specific field numbers | ||||
|     for field in common_fields: | ||||
|         ti = create_field_type_info(field) | ||||
|         ti = create_field_type_info(field, needs_decode, needs_encode) | ||||
|  | ||||
|         # Get field_ifdef if it's consistent across all messages | ||||
|         field_ifdef = get_common_field_ifdef(field.name, messages) | ||||
| @@ -1586,12 +1650,6 @@ def build_base_class( | ||||
|         if ti.public_content: | ||||
|             public_content.extend(wrap_with_ifdef(ti.public_content, field_ifdef)) | ||||
|  | ||||
|     # Determine if any message using this base class needs decoding | ||||
|     needs_decode = any( | ||||
|         message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT) | ||||
|         for msg in messages | ||||
|     ) | ||||
|  | ||||
|     # Build header | ||||
|     parent_class = "ProtoDecodableMessage" if needs_decode else "ProtoMessage" | ||||
|     out = f"class {base_class_name} : public {parent_class} {{\n" | ||||
|   | ||||
		Reference in New Issue
	
	Block a user