mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 07:03:55 +00:00 
			
		
		
		
	[api] Implement zero-copy string optimization for outgoing protobuf messages (#9790)
Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
This commit is contained in:
		| @@ -343,6 +343,10 @@ def create_field_type_info( | ||||
|     if field.type == 12: | ||||
|         return BytesType(field, needs_decode, needs_encode) | ||||
|  | ||||
|     # Special handling for string fields | ||||
|     if field.type == 9: | ||||
|         return StringType(field, needs_decode, needs_encode) | ||||
|  | ||||
|     validate_field_type(field.type, field.name) | ||||
|     return TYPE_INFO[field.type](field) | ||||
|  | ||||
| @@ -543,12 +547,67 @@ class StringType(TypeInfo): | ||||
|     encode_func = "encode_string" | ||||
|     wire_type = WireType.LENGTH_DELIMITED  # Uses wire type 2 | ||||
|  | ||||
|     @property | ||||
|     def public_content(self) -> list[str]: | ||||
|         content: list[str] = [] | ||||
|         # Add std::string storage if message needs decoding | ||||
|         if self._needs_decode: | ||||
|             content.append(f"std::string {self.field_name}{{}};") | ||||
|  | ||||
|         if self._needs_encode: | ||||
|             content.extend( | ||||
|                 [ | ||||
|                     # Add StringRef field if message needs encoding | ||||
|                     f"StringRef {self.field_name}_ref_{{}};", | ||||
|                     # Add setter method if message needs encoding | ||||
|                     f"void set_{self.field_name}(const StringRef &ref) {{", | ||||
|                     f"  this->{self.field_name}_ref_ = ref;", | ||||
|                     "}", | ||||
|                 ] | ||||
|             ) | ||||
|         return content | ||||
|  | ||||
|     @property | ||||
|     def encode_content(self) -> str: | ||||
|         return f"buffer.encode_string({self.number}, this->{self.field_name}_ref_);" | ||||
|  | ||||
|     def dump(self, name): | ||||
|         o = f'out.append("\'").append({name}).append("\'");' | ||||
|         return o | ||||
|         # If name is 'it', this is a repeated field element - always use string | ||||
|         if name == "it": | ||||
|             return "append_quoted_string(out, StringRef(it));" | ||||
|  | ||||
|         # For SOURCE_CLIENT only, always use std::string | ||||
|         if not self._needs_encode: | ||||
|             return f'out.append("\'").append(this->{self.field_name}).append("\'");' | ||||
|  | ||||
|         # For SOURCE_SERVER, always use StringRef | ||||
|         if not self._needs_decode: | ||||
|             return f"append_quoted_string(out, this->{self.field_name}_ref_);" | ||||
|  | ||||
|         # For SOURCE_BOTH, check if StringRef is set (sending) or use string (received) | ||||
|         return ( | ||||
|             f"if (!this->{self.field_name}_ref_.empty()) {{" | ||||
|             f'  out.append("\'").append(this->{self.field_name}_ref_.c_str()).append("\'");' | ||||
|             f"}} else {{" | ||||
|             f'  out.append("\'").append(this->{self.field_name}).append("\'");' | ||||
|             f"}}" | ||||
|         ) | ||||
|  | ||||
|     def get_size_calculation(self, name: str, force: bool = False) -> str: | ||||
|         return self._get_simple_size_calculation(name, force, "add_string_field") | ||||
|         # For SOURCE_CLIENT only messages, use the string field directly | ||||
|         if not self._needs_encode: | ||||
|             return self._get_simple_size_calculation(name, force, "add_string_field") | ||||
|  | ||||
|         # Check if this is being called from a repeated field context | ||||
|         # In that case, 'name' will be 'it' and we need to use the repeated version | ||||
|         if name == "it": | ||||
|             # For repeated fields, we need to use add_string_field_repeated which includes field ID | ||||
|             field_id_size = self.calculate_field_id_size() | ||||
|             return f"ProtoSize::add_string_field_repeated(total_size, {field_id_size}, it);" | ||||
|  | ||||
|         # For messages that need encoding, use the StringRef size | ||||
|         field_id_size = self.calculate_field_id_size() | ||||
|         return f"ProtoSize::add_string_field(total_size, {field_id_size}, this->{self.field_name}_ref_.size());" | ||||
|  | ||||
|     def get_estimated_size(self) -> int: | ||||
|         return self.calculate_field_id_size() + 8  # field ID + 8 bytes typical string | ||||
| @@ -1902,6 +1961,7 @@ def main() -> None: | ||||
| #pragma once | ||||
|  | ||||
| #include "esphome/core/defines.h" | ||||
| #include "esphome/core/string_ref.h" | ||||
|  | ||||
| #include "proto.h" | ||||
|  | ||||
| @@ -1935,6 +1995,15 @@ namespace api { | ||||
| namespace esphome { | ||||
| namespace api { | ||||
|  | ||||
| // Helper function to append a quoted string, handling empty StringRef | ||||
| static inline void append_quoted_string(std::string &out, const StringRef &ref) { | ||||
|   out.append("'"); | ||||
|   if (!ref.empty()) { | ||||
|     out.append(ref.c_str()); | ||||
|   } | ||||
|   out.append("'"); | ||||
| } | ||||
|  | ||||
| """ | ||||
|  | ||||
|     content += "namespace enums {\n\n" | ||||
| @@ -2174,7 +2243,13 @@ static const char *const TAG = "api.service"; | ||||
|             cpp += f"#ifdef {ifdef}\n" | ||||
|  | ||||
|         hpp_protected += f"  void {on_func}(const {inp} &msg) override;\n" | ||||
|         hpp += f"  virtual {ret} {func}(const {inp} &msg) = 0;\n" | ||||
|  | ||||
|         # For non-void methods, generate a send_ method instead of return-by-value | ||||
|         if is_void: | ||||
|             hpp += f"  virtual void {func}(const {inp} &msg) = 0;\n" | ||||
|         else: | ||||
|             hpp += f"  virtual bool send_{func}_response(const {inp} &msg) = 0;\n" | ||||
|  | ||||
|         cpp += f"void {class_name}::{on_func}(const {inp} &msg) {{\n" | ||||
|  | ||||
|         # Start with authentication/connection check if needed | ||||
| @@ -2192,10 +2267,7 @@ static const char *const TAG = "api.service"; | ||||
|             if is_void: | ||||
|                 handler_body = f"this->{func}(msg);\n" | ||||
|             else: | ||||
|                 handler_body = f"{ret} ret = this->{func}(msg);\n" | ||||
|                 handler_body += ( | ||||
|                     f"if (!this->send_message(ret, {ret}::MESSAGE_TYPE)) {{\n" | ||||
|                 ) | ||||
|                 handler_body = f"if (!this->send_{func}_response(msg)) {{\n" | ||||
|                 handler_body += "  this->on_fatal_error();\n" | ||||
|                 handler_body += "}\n" | ||||
|  | ||||
| @@ -2207,8 +2279,7 @@ static const char *const TAG = "api.service"; | ||||
|             if is_void: | ||||
|                 body += f"this->{func}(msg);\n" | ||||
|             else: | ||||
|                 body += f"{ret} ret = this->{func}(msg);\n" | ||||
|                 body += f"if (!this->send_message(ret, {ret}::MESSAGE_TYPE)) {{\n" | ||||
|                 body += f"if (!this->send_{func}_response(msg)) {{\n" | ||||
|                 body += "  this->on_fatal_error();\n" | ||||
|                 body += "}\n" | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user