mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 15:12:06 +00:00 
			
		
		
		
	Fix bluetooth_proxy heap allocations during BLE scanning (#9633)
This commit is contained in:
		| @@ -313,13 +313,18 @@ def validate_field_type(field_type: int, field_name: str = "") -> None: | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def get_type_info_for_field(field: descriptor.FieldDescriptorProto) -> TypeInfo: | ||||
|     """Get the appropriate TypeInfo for a field, handling repeated fields. | ||||
|  | ||||
|     Also validates that the field type is supported. | ||||
|     """ | ||||
| def create_field_type_info(field: descriptor.FieldDescriptorProto) -> TypeInfo: | ||||
|     """Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options.""" | ||||
|     if field.label == 3:  # repeated | ||||
|         return RepeatedTypeInfo(field) | ||||
|  | ||||
|     # Check for fixed_array_size option on bytes fields | ||||
|     if ( | ||||
|         field.type == 12 | ||||
|         and (fixed_size := get_field_opt(field, pb.fixed_array_size)) is not None | ||||
|     ): | ||||
|         return FixedArrayBytesType(field, fixed_size) | ||||
|  | ||||
|     validate_field_type(field.type, field.name) | ||||
|     return TYPE_INFO[field.type](field) | ||||
|  | ||||
| @@ -603,6 +608,85 @@ class BytesType(TypeInfo): | ||||
|         return self.calculate_field_id_size() + 8  # field ID + 8 bytes typical bytes | ||||
|  | ||||
|  | ||||
| class FixedArrayBytesType(TypeInfo): | ||||
|     """Special type for fixed-size byte arrays.""" | ||||
|  | ||||
|     def __init__(self, field: descriptor.FieldDescriptorProto, size: int) -> None: | ||||
|         super().__init__(field) | ||||
|         self.array_size = size | ||||
|  | ||||
|     @property | ||||
|     def cpp_type(self) -> str: | ||||
|         return "uint8_t" | ||||
|  | ||||
|     @property | ||||
|     def default_value(self) -> str: | ||||
|         return "{}" | ||||
|  | ||||
|     @property | ||||
|     def reference_type(self) -> str: | ||||
|         return f"uint8_t (&)[{self.array_size}]" | ||||
|  | ||||
|     @property | ||||
|     def const_reference_type(self) -> str: | ||||
|         return f"const uint8_t (&)[{self.array_size}]" | ||||
|  | ||||
|     @property | ||||
|     def public_content(self) -> list[str]: | ||||
|         # Add both the array and length fields | ||||
|         return [ | ||||
|             f"uint8_t {self.field_name}[{self.array_size}]{{}};", | ||||
|             f"uint8_t {self.field_name}_len{{0}};", | ||||
|         ] | ||||
|  | ||||
|     @property | ||||
|     def decode_length_content(self) -> str: | ||||
|         o = f"case {self.number}: {{\n" | ||||
|         o += "  const std::string &data_str = value.as_string();\n" | ||||
|         o += f"  this->{self.field_name}_len = data_str.size();\n" | ||||
|         o += f"  if (this->{self.field_name}_len > {self.array_size}) {{\n" | ||||
|         o += f"    this->{self.field_name}_len = {self.array_size};\n" | ||||
|         o += "  }\n" | ||||
|         o += f"  memcpy(this->{self.field_name}, data_str.data(), this->{self.field_name}_len);\n" | ||||
|         o += "  break;\n" | ||||
|         o += "}" | ||||
|         return o | ||||
|  | ||||
|     @property | ||||
|     def encode_content(self) -> str: | ||||
|         return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len);" | ||||
|  | ||||
|     def dump(self, name: str) -> str: | ||||
|         o = f"out.append(format_hex_pretty({name}, {name}_len));" | ||||
|         return o | ||||
|  | ||||
|     def get_size_calculation(self, name: str, force: bool = False) -> str: | ||||
|         # Use the actual length stored in the _len field | ||||
|         length_field = f"this->{self.field_name}_len" | ||||
|         field_id_size = self.calculate_field_id_size() | ||||
|  | ||||
|         if force: | ||||
|             # For repeated fields, always calculate size | ||||
|             return f"total_size += {field_id_size} + ProtoSize::varint(static_cast<uint32_t>({length_field})) + {length_field};" | ||||
|         else: | ||||
|             # For non-repeated fields, skip if length is 0 (matching encode_string behavior) | ||||
|             return ( | ||||
|                 f"if ({length_field} != 0) {{\n" | ||||
|                 f"  total_size += {field_id_size} + ProtoSize::varint(static_cast<uint32_t>({length_field})) + {length_field};\n" | ||||
|                 f"}}" | ||||
|             ) | ||||
|  | ||||
|     def get_estimated_size(self) -> int: | ||||
|         # Estimate based on typical BLE advertisement size | ||||
|         return ( | ||||
|             self.calculate_field_id_size() + 1 + 31 | ||||
|         )  # field ID + length byte + typical 31 bytes | ||||
|  | ||||
|     @property | ||||
|     def wire_type(self) -> WireType: | ||||
|         return WireType.LENGTH_DELIMITED | ||||
|  | ||||
|  | ||||
| @register_type(13) | ||||
| class UInt32Type(TypeInfo): | ||||
|     cpp_type = "uint32_t" | ||||
| @@ -748,6 +832,16 @@ class SInt64Type(TypeInfo): | ||||
| class RepeatedTypeInfo(TypeInfo): | ||||
|     def __init__(self, field: descriptor.FieldDescriptorProto) -> None: | ||||
|         super().__init__(field) | ||||
|         # For repeated fields, we need to get the base type info | ||||
|         # but we can't call create_field_type_info as it would cause recursion | ||||
|         # So we extract just the type creation logic | ||||
|         if ( | ||||
|             field.type == 12 | ||||
|             and (fixed_size := get_field_opt(field, pb.fixed_array_size)) is not None | ||||
|         ): | ||||
|             self._ti: TypeInfo = FixedArrayBytesType(field, fixed_size) | ||||
|             return | ||||
|  | ||||
|         validate_field_type(field.type, field.name) | ||||
|         self._ti: TypeInfo = TYPE_INFO[field.type](field) | ||||
|  | ||||
| @@ -1051,7 +1145,7 @@ def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int: | ||||
|     total_size = 0 | ||||
|  | ||||
|     for field in desc.field: | ||||
|         ti = get_type_info_for_field(field) | ||||
|         ti = create_field_type_info(field) | ||||
|  | ||||
|         # Add estimated size for this field | ||||
|         total_size += ti.get_estimated_size() | ||||
| @@ -1119,10 +1213,7 @@ def build_message_type( | ||||
|         public_content.append("#endif") | ||||
|  | ||||
|     for field in desc.field: | ||||
|         if field.label == 3: | ||||
|             ti = RepeatedTypeInfo(field) | ||||
|         else: | ||||
|             ti = TYPE_INFO[field.type](field) | ||||
|         ti = create_field_type_info(field) | ||||
|  | ||||
|         # Skip field declarations for fields that are in the base class | ||||
|         # but include their encode/decode logic | ||||
| @@ -1327,6 +1418,17 @@ def get_opt( | ||||
|     return desc.options.Extensions[opt] | ||||
|  | ||||
|  | ||||
| def get_field_opt( | ||||
|     field: descriptor.FieldDescriptorProto, | ||||
|     opt: descriptor.FieldOptions, | ||||
|     default: Any = None, | ||||
| ) -> Any: | ||||
|     """Get the option from a field descriptor.""" | ||||
|     if not field.options.HasExtension(opt): | ||||
|         return default | ||||
|     return field.options.Extensions[opt] | ||||
|  | ||||
|  | ||||
| def get_base_class(desc: descriptor.DescriptorProto) -> str | None: | ||||
|     """Get the base_class option from a message descriptor.""" | ||||
|     if not desc.options.HasExtension(pb.base_class): | ||||
| @@ -1401,7 +1503,7 @@ def build_base_class( | ||||
|     # For base classes, we only declare the fields but don't handle encode/decode | ||||
|     # The derived classes will handle encoding/decoding with their specific field numbers | ||||
|     for field in common_fields: | ||||
|         ti = get_type_info_for_field(field) | ||||
|         ti = create_field_type_info(field) | ||||
|  | ||||
|         # Only add field declarations, not encode/decode logic | ||||
|         protected_content.extend(ti.protected_content) | ||||
| @@ -1543,6 +1645,7 @@ namespace api { | ||||
|     #include "api_pb2.h" | ||||
|     #include "esphome/core/log.h" | ||||
|     #include "esphome/core/helpers.h" | ||||
|     #include <cstring> | ||||
|  | ||||
| namespace esphome { | ||||
| namespace api { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user