1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-05 21:02:20 +01:00

Merge remote-tracking branch 'upstream/dev' into ruff_ret

This commit is contained in:
J. Nick Koston
2025-07-29 22:48:15 -10:00
83 changed files with 2942 additions and 1320 deletions

View File

@@ -275,13 +275,13 @@ class TypeInfo(ABC):
Args:
name: Field name
force: Whether this is for a repeated field
base_method: Base method name (e.g., "add_int32_field")
base_method: Base method name (e.g., "add_int32")
value_expr: Optional value expression (defaults to name)
"""
field_id_size = self.calculate_field_id_size()
method = f"{base_method}_repeated" if force else base_method
method = f"{base_method}_force" if force else base_method
value = value_expr if value_expr else name
return f"ProtoSize::{method}(total_size, {field_id_size}, {value});"
return f"size.{method}({field_id_size}, {value});"
@abstractmethod
def get_size_calculation(self, name: str, force: bool = False) -> str:
@@ -389,7 +389,7 @@ class DoubleType(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_double_field(total_size, {field_id_size}, {name});"
return f"size.add_double({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 8
@@ -413,7 +413,7 @@ class FloatType(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_float_field(total_size, {field_id_size}, {name});"
return f"size.add_float({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 4
@@ -436,7 +436,7 @@ class Int64Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_int64_field")
return self._get_simple_size_calculation(name, force, "add_int64")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -456,7 +456,7 @@ class UInt64Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_uint64_field")
return self._get_simple_size_calculation(name, force, "add_uint64")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -476,7 +476,7 @@ class Int32Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_int32_field")
return self._get_simple_size_calculation(name, force, "add_int32")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -497,7 +497,7 @@ class Fixed64Type(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_fixed64_field(total_size, {field_id_size}, {name});"
return f"size.add_fixed64({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 8
@@ -521,7 +521,7 @@ class Fixed32Type(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_fixed32_field(total_size, {field_id_size}, {name});"
return f"size.add_fixed32({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 4
@@ -542,7 +542,7 @@ class BoolType(TypeInfo):
return f"out.append(YESNO({name}));"
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_bool_field")
return self._get_simple_size_calculation(name, force, "add_bool")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 1 # field ID + 1 byte
@@ -561,11 +561,16 @@ class StringType(TypeInfo):
@property
def public_content(self) -> list[str]:
content: list[str] = []
# Add std::string storage if message needs decoding
if self._needs_decode:
# Check if no_zero_copy option is set
no_zero_copy = get_field_opt(self._field, pb.no_zero_copy, False)
# Add std::string storage if message needs decoding OR if no_zero_copy is set
if self._needs_decode or no_zero_copy:
content.append(f"std::string {self.field_name}{{}};")
if self._needs_encode:
# Only add StringRef if encoding is needed AND no_zero_copy is not set
if self._needs_encode and not no_zero_copy:
content.extend(
[
# Add StringRef field if message needs encoding
@@ -580,13 +585,27 @@ class StringType(TypeInfo):
@property
def encode_content(self) -> str:
# Check if no_zero_copy option is set
no_zero_copy = get_field_opt(self._field, pb.no_zero_copy, False)
if no_zero_copy:
# Use the std::string directly
return f"buffer.encode_string({self.number}, this->{self.field_name});"
# Use the StringRef
return f"buffer.encode_string({self.number}, this->{self.field_name}_ref_);"
def dump(self, name):
# Check if no_zero_copy option is set
no_zero_copy = get_field_opt(self._field, pb.no_zero_copy, False)
# If name is 'it', this is a repeated field element - always use string
if name == "it":
return "append_quoted_string(out, StringRef(it));"
# If no_zero_copy is set, always use std::string
if no_zero_copy:
return f'out.append("\'").append(this->{self.field_name}).append("\'");'
# For SOURCE_CLIENT only, always use std::string
if not self._needs_encode:
return f'out.append("\'").append(this->{self.field_name}).append("\'");'
@@ -606,6 +625,13 @@ class StringType(TypeInfo):
@property
def dump_content(self) -> str:
# Check if no_zero_copy option is set
no_zero_copy = get_field_opt(self._field, pb.no_zero_copy, False)
# If no_zero_copy is set, always use std::string
if no_zero_copy:
return f'dump_field(out, "{self.name}", this->{self.field_name});'
# For SOURCE_CLIENT only, use std::string
if not self._needs_encode:
return f'dump_field(out, "{self.name}", this->{self.field_name});'
@@ -621,20 +647,29 @@ class StringType(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
# For SOURCE_CLIENT only messages, use the string field directly
if not self._needs_encode:
return self._get_simple_size_calculation(name, force, "add_string_field")
# Check if no_zero_copy option is set
no_zero_copy = get_field_opt(self._field, pb.no_zero_copy, False)
# For SOURCE_CLIENT only messages or no_zero_copy, use the string field directly
if not self._needs_encode or no_zero_copy:
# For no_zero_copy, we need to use .size() on the string
if no_zero_copy and name != "it":
field_id_size = self.calculate_field_id_size()
return (
f"size.add_length({field_id_size}, this->{self.field_name}.size());"
)
return self._get_simple_size_calculation(name, force, "add_length")
# Check if this is being called from a repeated field context
# In that case, 'name' will be 'it' and we need to use the repeated version
if name == "it":
# For repeated fields, we need to use add_string_field_repeated which includes field ID
# For repeated fields, we need to use add_length_force which includes field ID
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_string_field_repeated(total_size, {field_id_size}, it);"
return f"size.add_length_force({field_id_size}, it.size());"
# For messages that need encoding, use the StringRef size
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_string_field(total_size, {field_id_size}, this->{self.field_name}_ref_.size());"
return f"size.add_length({field_id_size}, this->{self.field_name}_ref_.size());"
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical string
@@ -768,7 +803,7 @@ class BytesType(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return f"ProtoSize::add_bytes_field(total_size, {self.calculate_field_id_size()}, this->{self.field_name}_len_);"
return f"size.add_length({self.calculate_field_id_size()}, this->{self.field_name}_len_);"
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical bytes
@@ -842,14 +877,10 @@ class FixedArrayBytesType(TypeInfo):
field_id_size = self.calculate_field_id_size()
if force:
# For repeated fields, always calculate size
return f"total_size += {field_id_size} + ProtoSize::varint(static_cast<uint32_t>({length_field})) + {length_field};"
# For non-repeated fields, skip if length is 0 (matching encode_string behavior)
return (
f"if ({length_field} != 0) {{\n"
f" total_size += {field_id_size} + ProtoSize::varint(static_cast<uint32_t>({length_field})) + {length_field};\n"
f"}}"
)
# For repeated fields, always calculate size (no zero check)
return f"size.add_length_force({field_id_size}, {length_field});"
# For non-repeated fields, add_length already checks for zero
return f"size.add_length({field_id_size}, {length_field});"
def get_estimated_size(self) -> int:
# Estimate based on typical BLE advertisement size
@@ -876,7 +907,7 @@ class UInt32Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_uint32_field")
return self._get_simple_size_calculation(name, force, "add_uint32")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -912,7 +943,7 @@ class EnumType(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(
name, force, "add_enum_field", f"static_cast<uint32_t>({name})"
name, force, "add_uint32", f"static_cast<uint32_t>({name})"
)
def get_estimated_size(self) -> int:
@@ -934,7 +965,7 @@ class SFixed32Type(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_sfixed32_field(total_size, {field_id_size}, {name});"
return f"size.add_sfixed32({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 4
@@ -958,7 +989,7 @@ class SFixed64Type(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_sfixed64_field(total_size, {field_id_size}, {name});"
return f"size.add_sfixed64({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 8
@@ -981,7 +1012,7 @@ class SInt32Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_sint32_field")
return self._get_simple_size_calculation(name, force, "add_sint32")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -1001,7 +1032,7 @@ class SInt64Type(TypeInfo):
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_sint64_field")
return self._get_simple_size_calculation(name, force, "add_sint64")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@@ -1132,6 +1163,10 @@ class FixedArrayRepeatedType(TypeInfo):
class RepeatedTypeInfo(TypeInfo):
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
super().__init__(field)
# Check if this is a pointer field by looking for container_pointer option
self._container_type = get_field_opt(field, pb.container_pointer, "")
self._use_pointer = bool(self._container_type)
# For repeated fields, we need to get the base type info
# but we can't call create_field_type_info as it would cause recursion
# So we extract just the type creation logic
@@ -1147,6 +1182,13 @@ class RepeatedTypeInfo(TypeInfo):
@property
def cpp_type(self) -> str:
if self._use_pointer and self._container_type:
# For pointer fields, use the specified container type
# If the container type already includes the element type (e.g., std::set<climate::ClimateMode>)
# use it as-is, otherwise append the element type
if "<" in self._container_type and ">" in self._container_type:
return f"const {self._container_type}*"
return f"const {self._container_type}<{self._ti.cpp_type}>*"
return f"std::vector<{self._ti.cpp_type}>"
@property
@@ -1167,6 +1209,9 @@ class RepeatedTypeInfo(TypeInfo):
@property
def decode_varint_content(self) -> str:
# Pointer fields don't support decoding
if self._use_pointer:
return None
content = self._ti.decode_varint
if content is None:
return None
@@ -1176,6 +1221,9 @@ class RepeatedTypeInfo(TypeInfo):
@property
def decode_length_content(self) -> str:
# Pointer fields don't support decoding
if self._use_pointer:
return None
content = self._ti.decode_length
if content is None and isinstance(self._ti, MessageType):
# Special handling for non-template message decoding
@@ -1188,6 +1236,9 @@ class RepeatedTypeInfo(TypeInfo):
@property
def decode_32bit_content(self) -> str:
# Pointer fields don't support decoding
if self._use_pointer:
return None
content = self._ti.decode_32bit
if content is None:
return None
@@ -1197,6 +1248,9 @@ class RepeatedTypeInfo(TypeInfo):
@property
def decode_64bit_content(self) -> str:
# Pointer fields don't support decoding
if self._use_pointer:
return None
content = self._ti.decode_64bit
if content is None:
return None
@@ -1211,6 +1265,15 @@ class RepeatedTypeInfo(TypeInfo):
@property
def encode_content(self) -> str:
if self._use_pointer:
# For pointer fields, just dereference (pointer should never be null in our use case)
o = f"for (const auto &it : *this->{self.field_name}) {{\n"
if isinstance(self._ti, EnumType):
o += f" buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>(it), true);\n"
else:
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
o += "}"
return o
o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
if isinstance(self._ti, EnumType):
o += f" buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>(it), true);\n"
@@ -1221,6 +1284,11 @@ class RepeatedTypeInfo(TypeInfo):
@property
def dump_content(self) -> str:
if self._use_pointer:
# For pointer fields, dereference and use the existing helper
return _generate_array_dump_content(
self._ti, f"*this->{self.field_name}", self.name, is_bool=False
)
return _generate_array_dump_content(
self._ti, f"this->{self.field_name}", self.name, is_bool=self._ti_is_bool
)
@@ -1231,29 +1299,34 @@ class RepeatedTypeInfo(TypeInfo):
def get_size_calculation(self, name: str, force: bool = False) -> str:
# For repeated fields, we always need to pass force=True to the underlying type's calculation
# This is because the encode method always sets force=true for repeated fields
# Handle message types separately as they use a dedicated helper
if isinstance(self._ti, MessageType):
# For repeated messages, use the dedicated helper that handles iteration internally
field_id_size = self._ti.calculate_field_id_size()
return (
f"ProtoSize::add_repeated_message(total_size, {field_id_size}, {name});"
)
container = f"*{name}" if self._use_pointer else name
return f"size.add_repeated_message({field_id_size}, {container});"
# For other repeated types, use the underlying type's size calculation with force=True
o = f"if (!{name}.empty()) {{\n"
# For non-message types, generate size calculation with iteration
container_ref = f"*{name}" if self._use_pointer else name
empty_check = f"{name}->empty()" if self._use_pointer else f"{name}.empty()"
# Check if this is a fixed-size type by seeing if it has a fixed byte count
o = f"if (!{empty_check}) {{\n"
# Check if this is a fixed-size type
num_bytes = self._ti.get_fixed_size_bytes()
if num_bytes is not None:
# Fixed types have constant size per element, so we can multiply
# Fixed types have constant size per element
field_id_size = self._ti.calculate_field_id_size()
# Pre-calculate the total bytes per element
bytes_per_element = field_id_size + num_bytes
o += f" total_size += {name}.size() * {bytes_per_element};\n"
size_expr = f"{name}->size()" if self._use_pointer else f"{name}.size()"
o += f" size.add_precalculated_size({size_expr} * {bytes_per_element});\n"
else:
# Other types need the actual value
o += f" for (const auto {'' if self._ti_is_bool else '&'}it : {name}) {{\n"
auto_ref = "" if self._ti_is_bool else "&"
o += f" for (const auto {auto_ref}it : {container_ref}) {{\n"
o += f" {self._ti.get_size_calculation('it', True)}\n"
o += " }\n"
o += "}"
return o
@@ -1680,11 +1753,11 @@ def build_message_type(
if needs_encode and encode:
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"
if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120:
o += f" {encode[0]} "
o += f" {encode[0]} }}\n"
else:
o += "\n"
o += indent("\n".join(encode)) + "\n"
o += "}\n"
o += "}\n"
cpp += o
prot = "void encode(ProtoWriteBuffer buffer) const override;"
public_content.append(prot)
@@ -1692,17 +1765,17 @@ def build_message_type(
# Add calculate_size method only if this message needs encoding and has fields
if needs_encode and size_calc:
o = f"void {desc.name}::calculate_size(uint32_t &total_size) const {{"
o = f"void {desc.name}::calculate_size(ProtoSize &size) const {{"
# For a single field, just inline it for simplicity
if len(size_calc) == 1 and len(size_calc[0]) + len(o) + 3 < 120:
o += f" {size_calc[0]} "
o += f" {size_calc[0]} }}\n"
else:
# For multiple fields
o += "\n"
o += indent("\n".join(size_calc)) + "\n"
o += "}\n"
o += "}\n"
cpp += o
prot = "void calculate_size(uint32_t &total_size) const override;"
prot = "void calculate_size(ProtoSize &size) const override;"
public_content.append(prot)
# If no fields to calculate size for or message doesn't need encoding, the default implementation in ProtoMessage will be used
@@ -1733,8 +1806,13 @@ def build_message_type(
if base_class:
out = f"class {desc.name} : public {base_class} {{\n"
else:
# Determine inheritance based on whether the message needs decoding
base_class = "ProtoDecodableMessage" if needs_decode else "ProtoMessage"
# Check if message has any non-deprecated fields
has_fields = any(not field.options.deprecated for field in desc.field)
# Determine inheritance based on whether the message needs decoding and has fields
if needs_decode and has_fields:
base_class = "ProtoDecodableMessage"
else:
base_class = "ProtoMessage"
out = f"class {desc.name} : public {base_class} {{\n"
out += " public:\n"
out += indent("\n".join(public_content)) + "\n"
@@ -2000,7 +2078,14 @@ def build_service_message_type(
hout += f"virtual void {func}(const {mt.name} &value){{}};\n"
case = ""
case += f"{mt.name} msg;\n"
case += "msg.decode(msg_data, msg_size);\n"
# Check if this message has any fields (excluding deprecated ones)
has_fields = any(not field.options.deprecated for field in mt.field)
if has_fields:
# Normal case: decode the message
case += "msg.decode(msg_data, msg_size);\n"
else:
# Empty message optimization: skip decode since there are no fields
case += "// Empty message: no decode needed\n"
if log:
case += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
case += f'ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n'
@@ -2029,6 +2114,7 @@ def main() -> None:
d = descriptor.FileDescriptorSet.FromString(proto_content)
file = d.file[0]
content = FILE_HEADER
content += """\
#pragma once
@@ -2037,7 +2123,10 @@ def main() -> None:
#include "esphome/core/string_ref.h"
#include "proto.h"
#include "api_pb2_includes.h"
"""
content += """
namespace esphome::api {
"""