mirror of
https://github.com/esphome/esphome.git
synced 2025-09-01 10:52:19 +01:00
Fix bluetooth_proxy heap allocations during BLE scanning (#9633)
This commit is contained in:
@@ -313,13 +313,18 @@ def validate_field_type(field_type: int, field_name: str = "") -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_type_info_for_field(field: descriptor.FieldDescriptorProto) -> TypeInfo:
|
||||
"""Get the appropriate TypeInfo for a field, handling repeated fields.
|
||||
|
||||
Also validates that the field type is supported.
|
||||
"""
|
||||
def create_field_type_info(field: descriptor.FieldDescriptorProto) -> TypeInfo:
|
||||
"""Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options."""
|
||||
if field.label == 3: # repeated
|
||||
return RepeatedTypeInfo(field)
|
||||
|
||||
# Check for fixed_array_size option on bytes fields
|
||||
if (
|
||||
field.type == 12
|
||||
and (fixed_size := get_field_opt(field, pb.fixed_array_size)) is not None
|
||||
):
|
||||
return FixedArrayBytesType(field, fixed_size)
|
||||
|
||||
validate_field_type(field.type, field.name)
|
||||
return TYPE_INFO[field.type](field)
|
||||
|
||||
@@ -603,6 +608,85 @@ class BytesType(TypeInfo):
|
||||
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical bytes
|
||||
|
||||
|
||||
class FixedArrayBytesType(TypeInfo):
|
||||
"""Special type for fixed-size byte arrays."""
|
||||
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto, size: int) -> None:
|
||||
super().__init__(field)
|
||||
self.array_size = size
|
||||
|
||||
@property
|
||||
def cpp_type(self) -> str:
|
||||
return "uint8_t"
|
||||
|
||||
@property
|
||||
def default_value(self) -> str:
|
||||
return "{}"
|
||||
|
||||
@property
|
||||
def reference_type(self) -> str:
|
||||
return f"uint8_t (&)[{self.array_size}]"
|
||||
|
||||
@property
|
||||
def const_reference_type(self) -> str:
|
||||
return f"const uint8_t (&)[{self.array_size}]"
|
||||
|
||||
@property
|
||||
def public_content(self) -> list[str]:
|
||||
# Add both the array and length fields
|
||||
return [
|
||||
f"uint8_t {self.field_name}[{self.array_size}]{{}};",
|
||||
f"uint8_t {self.field_name}_len{{0}};",
|
||||
]
|
||||
|
||||
@property
|
||||
def decode_length_content(self) -> str:
|
||||
o = f"case {self.number}: {{\n"
|
||||
o += " const std::string &data_str = value.as_string();\n"
|
||||
o += f" this->{self.field_name}_len = data_str.size();\n"
|
||||
o += f" if (this->{self.field_name}_len > {self.array_size}) {{\n"
|
||||
o += f" this->{self.field_name}_len = {self.array_size};\n"
|
||||
o += " }\n"
|
||||
o += f" memcpy(this->{self.field_name}, data_str.data(), this->{self.field_name}_len);\n"
|
||||
o += " break;\n"
|
||||
o += "}"
|
||||
return o
|
||||
|
||||
@property
|
||||
def encode_content(self) -> str:
|
||||
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len);"
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f"out.append(format_hex_pretty({name}, {name}_len));"
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
# Use the actual length stored in the _len field
|
||||
length_field = f"this->{self.field_name}_len"
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
|
||||
if force:
|
||||
# For repeated fields, always calculate size
|
||||
return f"total_size += {field_id_size} + ProtoSize::varint(static_cast<uint32_t>({length_field})) + {length_field};"
|
||||
else:
|
||||
# For non-repeated fields, skip if length is 0 (matching encode_string behavior)
|
||||
return (
|
||||
f"if ({length_field} != 0) {{\n"
|
||||
f" total_size += {field_id_size} + ProtoSize::varint(static_cast<uint32_t>({length_field})) + {length_field};\n"
|
||||
f"}}"
|
||||
)
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
# Estimate based on typical BLE advertisement size
|
||||
return (
|
||||
self.calculate_field_id_size() + 1 + 31
|
||||
) # field ID + length byte + typical 31 bytes
|
||||
|
||||
@property
|
||||
def wire_type(self) -> WireType:
|
||||
return WireType.LENGTH_DELIMITED
|
||||
|
||||
|
||||
@register_type(13)
|
||||
class UInt32Type(TypeInfo):
|
||||
cpp_type = "uint32_t"
|
||||
@@ -748,6 +832,16 @@ class SInt64Type(TypeInfo):
|
||||
class RepeatedTypeInfo(TypeInfo):
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
||||
super().__init__(field)
|
||||
# For repeated fields, we need to get the base type info
|
||||
# but we can't call create_field_type_info as it would cause recursion
|
||||
# So we extract just the type creation logic
|
||||
if (
|
||||
field.type == 12
|
||||
and (fixed_size := get_field_opt(field, pb.fixed_array_size)) is not None
|
||||
):
|
||||
self._ti: TypeInfo = FixedArrayBytesType(field, fixed_size)
|
||||
return
|
||||
|
||||
validate_field_type(field.type, field.name)
|
||||
self._ti: TypeInfo = TYPE_INFO[field.type](field)
|
||||
|
||||
@@ -1051,7 +1145,7 @@ def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int:
|
||||
total_size = 0
|
||||
|
||||
for field in desc.field:
|
||||
ti = get_type_info_for_field(field)
|
||||
ti = create_field_type_info(field)
|
||||
|
||||
# Add estimated size for this field
|
||||
total_size += ti.get_estimated_size()
|
||||
@@ -1119,10 +1213,7 @@ def build_message_type(
|
||||
public_content.append("#endif")
|
||||
|
||||
for field in desc.field:
|
||||
if field.label == 3:
|
||||
ti = RepeatedTypeInfo(field)
|
||||
else:
|
||||
ti = TYPE_INFO[field.type](field)
|
||||
ti = create_field_type_info(field)
|
||||
|
||||
# Skip field declarations for fields that are in the base class
|
||||
# but include their encode/decode logic
|
||||
@@ -1327,6 +1418,17 @@ def get_opt(
|
||||
return desc.options.Extensions[opt]
|
||||
|
||||
|
||||
def get_field_opt(
|
||||
field: descriptor.FieldDescriptorProto,
|
||||
opt: descriptor.FieldOptions,
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
"""Get the option from a field descriptor."""
|
||||
if not field.options.HasExtension(opt):
|
||||
return default
|
||||
return field.options.Extensions[opt]
|
||||
|
||||
|
||||
def get_base_class(desc: descriptor.DescriptorProto) -> str | None:
|
||||
"""Get the base_class option from a message descriptor."""
|
||||
if not desc.options.HasExtension(pb.base_class):
|
||||
@@ -1401,7 +1503,7 @@ def build_base_class(
|
||||
# For base classes, we only declare the fields but don't handle encode/decode
|
||||
# The derived classes will handle encoding/decoding with their specific field numbers
|
||||
for field in common_fields:
|
||||
ti = get_type_info_for_field(field)
|
||||
ti = create_field_type_info(field)
|
||||
|
||||
# Only add field declarations, not encode/decode logic
|
||||
protected_content.extend(ti.protected_content)
|
||||
@@ -1543,6 +1645,7 @@ namespace api {
|
||||
#include "api_pb2.h"
|
||||
#include "esphome/core/log.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
#include <cstring>
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
Reference in New Issue
Block a user