1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-06 13:22:19 +01:00

zero_copy_str

This commit is contained in:
J. Nick Koston
2025-07-21 17:02:01 -10:00
parent b17e2019c7
commit c4ac22286f
12 changed files with 1483 additions and 558 deletions

View File

@@ -340,6 +340,10 @@ def create_field_type_info(
if field.type == 12:
return BytesType(field, needs_decode, needs_encode)
# Special handling for string fields
if field.type == 9:
return StringType(field, needs_decode, needs_encode)
validate_field_type(field.type, field.name)
return TYPE_INFO[field.type](field)
@@ -540,12 +544,70 @@ class StringType(TypeInfo):
encode_func = "encode_string"
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
@property
def public_content(self) -> list[str]:
content: list[str] = []
# Add std::string storage if message needs decoding
if self._needs_decode:
content.append(f"std::string {self.field_name}{{}};")
if self._needs_encode:
content.extend(
[
# Add pointer/length fields if message needs encoding
f"const char* {self.field_name}_ptr_{{nullptr}};",
f"size_t {self.field_name}_len_{{0}};",
# Add setter method if message needs encoding
f"void set_{self.field_name}(const char* data, size_t len) {{",
f" this->{self.field_name}_ptr_ = data;",
f" this->{self.field_name}_len_ = len;",
"}",
]
)
return content
@property
def encode_content(self) -> str:
return f"buffer.encode_string({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_);"
def dump(self, name):
o = f'out.append("\'").append({name}).append("\'");'
return o
# For SOURCE_CLIENT only, always use std::string
if not self._needs_encode:
return f'out.append("\'").append(this->{self.field_name}).append("\'");'
# For SOURCE_SERVER, always use pointer/length
if not self._needs_decode:
return (
f"if (this->{self.field_name}_ptr_ != nullptr) {{"
f' out.append("\'").append(this->{self.field_name}_ptr_).append("\'");'
f"}} else {{"
f' out.append("\'").append("").append("\'");'
f"}}"
)
# For SOURCE_BOTH, check if pointer is set (sending) or use string (received)
return (
f"if (this->{self.field_name}_ptr_ != nullptr) {{"
f' out.append("\'").append(this->{self.field_name}_ptr_).append("\'");'
f"}} else {{"
f' out.append("\'").append(this->{self.field_name}).append("\'");'
f"}}"
)
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "add_string_field")
# For SOURCE_CLIENT only messages, use the string field directly
if not self._needs_encode:
return self._get_simple_size_calculation(name, force, "add_string_field")
# Check if this is being called from a repeated field context
# In that case, 'name' will be 'it' and we need to use .length()
if name == "it":
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_string_field(total_size, {field_id_size}, it.length());"
# For messages that need encoding, use the length only
field_id_size = self.calculate_field_id_size()
return f"ProtoSize::add_string_field(total_size, {field_id_size}, this->{self.field_name}_len_);"
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical string