1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-01 10:52:19 +01:00

[api] Align ProtoSize API design with ProtoWriteBuffer pattern (#9920)

This commit is contained in:
J. Nick Koston
2025-07-28 12:28:32 -10:00
committed by GitHub
parent 68f388f78e
commit 2c9987869e
5 changed files with 795 additions and 865 deletions

View File

@@ -275,13 +275,13 @@ class TypeInfo(ABC):
Args:
name: Field name
force: Whether this is for a repeated field
base_method: Base method name (e.g., "add_int32_field")
base_method: Base method name (e.g., "add_int32")
value_expr: Optional value expression (defaults to name)
"""
field_id_size = self.calculate_field_id_size()
method = f"{base_method}_repeated" if force else base_method
method = f"{base_method}_force" if force else base_method
value = value_expr if value_expr else name
return f"ProtoSize::{method}(total_size, {field_id_size}, {value});"
return f"size.{method}({field_id_size}, {value});"
@abstractmethod
def get_size_calculation(self, name: str, force: bool = False) -> str:
@@ -389,7 +389,7 @@ class DoubleType(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_double_field(total_size, {field_id_size}, {name});"
return f"size.add_double({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 8
@@ -413,7 +413,7 @@ class FloatType(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_float_field(total_size, {field_id_size}, {name});"
return f"size.add_float({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 4
@@ -436,7 +436,7 @@ class Int64Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_int64_field")
return self._get_simple_size_calculation(name, force, "add_int64")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -456,7 +456,7 @@ class UInt64Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_uint64_field")
return self._get_simple_size_calculation(name, force, "add_uint64")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -476,7 +476,7 @@ class Int32Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_int32_field")
return self._get_simple_size_calculation(name, force, "add_int32")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -497,7 +497,7 @@ class Fixed64Type(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_fixed64_field(total_size, {field_id_size}, {name});"
return f"size.add_fixed64({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 8
@@ -521,7 +521,7 @@ class Fixed32Type(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_fixed32_field(total_size, {field_id_size}, {name});"
return f"size.add_fixed32({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 4
@@ -543,7 +543,7 @@ class BoolType(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_bool_field")
return self._get_simple_size_calculation(name, force, "add_bool")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 1 # field ID + 1 byte
@@ -657,19 +657,21 @@ class StringType(TypeInfo):
# For no_zero_copy, we need to use .size() on the string
if no_zero_copy and name != "it":
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_string_field(total_size, {field_id_size}, this->{self.field_name}.size());"
return self._get_simple_size_calculation(name, force, "add_string_field")
return (
f"size.add_length({field_id_size}, this->{self.field_name}.size());"
)
return self._get_simple_size_calculation(name, force, "add_length")
# Check if this is being called from a repeated field context
# In that case, 'name' will be 'it' and we need to use the repeated version
if name == "it":
# For repeated fields, we need to use add_string_field_repeated which includes field ID
# For repeated fields, we need to use add_length_force which includes field ID
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_string_field_repeated(total_size, {field_id_size}, it);"
return f"size.add_length_force({field_id_size}, it.size());"
# For messages that need encoding, use the StringRef size
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_string_field(total_size, {field_id_size}, this->{self.field_name}_ref_.size());"
return f"size.add_length({field_id_size}, this->{self.field_name}_ref_.size());"
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical string
@@ -804,7 +806,7 @@ class BytesType(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return f"ProtoSize::add_bytes_field(total_size, {self.calculate_field_id_size()}, this->{self.field_name}_len_);"
return f"size.add_length({self.calculate_field_id_size()}, this->{self.field_name}_len_);"
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical bytes
@@ -879,15 +881,11 @@ class FixedArrayBytesType(TypeInfo):
field_id_size = self.calculate_field_id_size()
if force:
# For repeated fields, always calculate size
return f"total_size += {field_id_size} + ProtoSize::varint(static_cast<uint32_t>({length_field})) + {length_field};"
# For repeated fields, always calculate size (no zero check)
return f"size.add_length_force({field_id_size}, {length_field});"
else:
# For non-repeated fields, skip if length is 0 (matching encode_string behavior)
return (
f"if ({length_field} != 0) {{\n"
f" total_size += {field_id_size} + ProtoSize::varint(static_cast<uint32_t>({length_field})) + {length_field};\n"
f"}}"
)
# For non-repeated fields, add_length already checks for zero
return f"size.add_length({field_id_size}, {length_field});"
def get_estimated_size(self) -> int:
# Estimate based on typical BLE advertisement size
@@ -914,7 +912,7 @@ class UInt32Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_uint32_field")
return self._get_simple_size_calculation(name, force, "add_uint32")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -951,7 +949,7 @@ class EnumType(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(
name, force, "add_enum_field", f"static_cast<uint32_t>({name})"
name, force, "add_uint32", f"static_cast<uint32_t>({name})"
)
def get_estimated_size(self) -> int:
@@ -973,7 +971,7 @@ class SFixed32Type(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_sfixed32_field(total_size, {field_id_size}, {name});"
return f"size.add_sfixed32({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 4
@@ -997,7 +995,7 @@ class SFixed64Type(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_sfixed64_field(total_size, {field_id_size}, {name});"
return f"size.add_sfixed64({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 8
@@ -1020,7 +1018,7 @@ class SInt32Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_sint32_field")
return self._get_simple_size_calculation(name, force, "add_sint32")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -1040,7 +1038,7 @@ class SInt64Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_sint64_field")
return self._get_simple_size_calculation(name, force, "add_sint64")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -1274,7 +1272,7 @@ class RepeatedTypeInfo(TypeInfo):
if isinstance(self._ti, MessageType):
# For repeated messages, use the dedicated helper that handles iteration internally
field_id_size = self._ti.calculate_field_id_size()
o = f"ProtoSize::add_repeated_message(total_size, {field_id_size}, {name});"
o = f"size.add_repeated_message({field_id_size}, {name});"
return o
# For other repeated types, use the underlying type's size calculation with force=True
@@ -1287,7 +1285,9 @@ class RepeatedTypeInfo(TypeInfo):
field_id_size = self._ti.calculate_field_id_size()
# Pre-calculate the total bytes per element
bytes_per_element = field_id_size + num_bytes
o += f" total_size += {name}.size() * {bytes_per_element};\n"
o += (
f" size.add_precalculated_size({name}.size() * {bytes_per_element});\n"
)
else:
# Other types need the actual value
o += f" for (const auto {'' if self._ti_is_bool else '&'}it : {name}) {{\n"
@@ -1719,11 +1719,11 @@ def build_message_type(
if needs_encode and encode:
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"
if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120:
o += f" {encode[0]} "
o += f" {encode[0]} }}\n"
else:
o += "\n"
o += indent("\n".join(encode)) + "\n"
o += "}\n"
o += "}\n"
cpp += o
prot = "void encode(ProtoWriteBuffer buffer) const override;"
public_content.append(prot)
@@ -1731,17 +1731,17 @@ def build_message_type(
# Add calculate_size method only if this message needs encoding and has fields
if needs_encode and size_calc:
o = f"void {desc.name}::calculate_size(uint32_t &total_size) const {{"
o = f"void {desc.name}::calculate_size(ProtoSize &size) const {{"
# For a single field, just inline it for simplicity
if len(size_calc) == 1 and len(size_calc[0]) + len(o) + 3 < 120:
o += f" {size_calc[0]} "
o += f" {size_calc[0]} }}\n"
else:
# For multiple fields
o += "\n"
o += indent("\n".join(size_calc)) + "\n"
o += "}\n"
o += "}\n"
cpp += o
prot = "void calculate_size(uint32_t &total_size) const override;"
prot = "void calculate_size(ProtoSize &size) const override;"
public_content.append(prot)
# If no fields to calculate size for or message doesn't need encoding, the default implementation in ProtoMessage will be used