mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 15:12:06 +00:00 
			
		
		
		
	[api] Implement zero-copy for all protobuf bytes fields (#9761)
This commit is contained in:
		| @@ -732,7 +732,6 @@ message SubscribeLogsResponse { | ||||
|  | ||||
|   LogLevel level = 1; | ||||
|   bytes message = 3; | ||||
|   bool send_failed = 4; | ||||
| } | ||||
|  | ||||
| // ==================== NOISE ENCRYPTION ==================== | ||||
|   | ||||
| @@ -225,24 +225,16 @@ void APIConnection::loop() { | ||||
|   if (this->image_reader_ && this->image_reader_->available() && this->helper_->can_write_without_blocking()) { | ||||
|     uint32_t to_send = std::min((size_t) MAX_BATCH_PACKET_SIZE, this->image_reader_->available()); | ||||
|     bool done = this->image_reader_->available() == to_send; | ||||
|     uint32_t msg_size = 0; | ||||
|     ProtoSize::add_fixed_field<4>(msg_size, 1, true); | ||||
|     // partial message size calculated manually since its a special case | ||||
|     // 1 for the data field, varint for the data size, and the data itself | ||||
|     msg_size += 1 + ProtoSize::varint(to_send) + to_send; | ||||
|     ProtoSize::add_bool_field(msg_size, 1, done); | ||||
|  | ||||
|     auto buffer = this->create_buffer(msg_size); | ||||
|     // fixed32 key = 1; | ||||
|     buffer.encode_fixed32(1, camera::Camera::instance()->get_object_id_hash()); | ||||
|     // bytes data = 2; | ||||
|     buffer.encode_bytes(2, this->image_reader_->peek_data_buffer(), to_send); | ||||
|     // bool done = 3; | ||||
|     buffer.encode_bool(3, done); | ||||
|     CameraImageResponse msg; | ||||
|     msg.key = camera::Camera::instance()->get_object_id_hash(); | ||||
|     msg.set_data(this->image_reader_->peek_data_buffer(), to_send); | ||||
|     msg.done = done; | ||||
| #ifdef USE_DEVICES | ||||
|     msg.device_id = camera::Camera::instance()->get_device_id(); | ||||
| #endif | ||||
|  | ||||
|     bool success = this->send_buffer(buffer, CameraImageResponse::MESSAGE_TYPE); | ||||
|  | ||||
|     if (success) { | ||||
|     if (this->send_message_(msg, CameraImageResponse::MESSAGE_TYPE)) { | ||||
|       this->image_reader_->consume_data(to_send); | ||||
|       if (done) { | ||||
|         this->image_reader_->return_image(); | ||||
| @@ -1350,26 +1342,10 @@ void APIConnection::update_command(const UpdateCommandRequest &msg) { | ||||
| #endif | ||||
|  | ||||
| bool APIConnection::try_send_log_message(int level, const char *tag, const char *line, size_t message_len) { | ||||
|   // Pre-calculate message size to avoid reallocations | ||||
|   uint32_t msg_size = 0; | ||||
|  | ||||
|   // Add size for level field (field ID 1, varint type) | ||||
|   // 1 byte for field tag + size of the level varint | ||||
|   msg_size += 1 + api::ProtoSize::varint(static_cast<uint32_t>(level)); | ||||
|  | ||||
|   // Add size for string field (field ID 3, string type) | ||||
|   // 1 byte for field tag + size of length varint + string length | ||||
|   msg_size += 1 + api::ProtoSize::varint(static_cast<uint32_t>(message_len)) + message_len; | ||||
|  | ||||
|   // Create a pre-sized buffer | ||||
|   auto buffer = this->create_buffer(msg_size); | ||||
|  | ||||
|   // Encode the message (SubscribeLogsResponse) | ||||
|   buffer.encode_uint32(1, static_cast<uint32_t>(level));  // LogLevel level = 1 | ||||
|   buffer.encode_string(3, line, message_len);             // string message = 3 | ||||
|  | ||||
|   // SubscribeLogsResponse - 29 | ||||
|   return this->send_buffer(buffer, SubscribeLogsResponse::MESSAGE_TYPE); | ||||
|   SubscribeLogsResponse msg; | ||||
|   msg.level = static_cast<enums::LogLevel>(level); | ||||
|   msg.set_message(reinterpret_cast<const uint8_t *>(line), message_len); | ||||
|   return this->send_message_(msg, SubscribeLogsResponse::MESSAGE_TYPE); | ||||
| } | ||||
|  | ||||
| void APIConnection::complete_authentication_() { | ||||
|   | ||||
| @@ -822,13 +822,11 @@ bool SubscribeLogsRequest::decode_varint(uint32_t field_id, ProtoVarInt value) { | ||||
| } | ||||
| void SubscribeLogsResponse::encode(ProtoWriteBuffer buffer) const { | ||||
|   buffer.encode_uint32(1, static_cast<uint32_t>(this->level)); | ||||
|   buffer.encode_bytes(3, reinterpret_cast<const uint8_t *>(this->message.data()), this->message.size()); | ||||
|   buffer.encode_bool(4, this->send_failed); | ||||
|   buffer.encode_bytes(3, this->message_ptr_, this->message_len_); | ||||
| } | ||||
| void SubscribeLogsResponse::calculate_size(uint32_t &total_size) const { | ||||
|   ProtoSize::add_enum_field(total_size, 1, static_cast<uint32_t>(this->level)); | ||||
|   ProtoSize::add_string_field(total_size, 1, this->message); | ||||
|   ProtoSize::add_bool_field(total_size, 1, this->send_failed); | ||||
|   ProtoSize::add_bytes_field(total_size, 1, this->message_len_); | ||||
| } | ||||
| #ifdef USE_API_NOISE | ||||
| bool NoiseEncryptionSetKeyRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { | ||||
| @@ -1034,7 +1032,7 @@ void ListEntitiesCameraResponse::calculate_size(uint32_t &total_size) const { | ||||
| } | ||||
| void CameraImageResponse::encode(ProtoWriteBuffer buffer) const { | ||||
|   buffer.encode_fixed32(1, this->key); | ||||
|   buffer.encode_bytes(2, reinterpret_cast<const uint8_t *>(this->data.data()), this->data.size()); | ||||
|   buffer.encode_bytes(2, this->data_ptr_, this->data_len_); | ||||
|   buffer.encode_bool(3, this->done); | ||||
| #ifdef USE_DEVICES | ||||
|   buffer.encode_uint32(4, this->device_id); | ||||
| @@ -1042,7 +1040,7 @@ void CameraImageResponse::encode(ProtoWriteBuffer buffer) const { | ||||
| } | ||||
| void CameraImageResponse::calculate_size(uint32_t &total_size) const { | ||||
|   ProtoSize::add_fixed32_field(total_size, 1, this->key); | ||||
|   ProtoSize::add_string_field(total_size, 1, this->data); | ||||
|   ProtoSize::add_bytes_field(total_size, 1, this->data_len_); | ||||
|   ProtoSize::add_bool_field(total_size, 1, this->done); | ||||
| #ifdef USE_DEVICES | ||||
|   ProtoSize::add_uint32_field(total_size, 1, this->device_id); | ||||
| @@ -1976,12 +1974,12 @@ bool BluetoothGATTReadRequest::decode_varint(uint32_t field_id, ProtoVarInt valu | ||||
| void BluetoothGATTReadResponse::encode(ProtoWriteBuffer buffer) const { | ||||
|   buffer.encode_uint64(1, this->address); | ||||
|   buffer.encode_uint32(2, this->handle); | ||||
|   buffer.encode_bytes(3, reinterpret_cast<const uint8_t *>(this->data.data()), this->data.size()); | ||||
|   buffer.encode_bytes(3, this->data_ptr_, this->data_len_); | ||||
| } | ||||
| void BluetoothGATTReadResponse::calculate_size(uint32_t &total_size) const { | ||||
|   ProtoSize::add_uint64_field(total_size, 1, this->address); | ||||
|   ProtoSize::add_uint32_field(total_size, 1, this->handle); | ||||
|   ProtoSize::add_string_field(total_size, 1, this->data); | ||||
|   ProtoSize::add_bytes_field(total_size, 1, this->data_len_); | ||||
| } | ||||
| bool BluetoothGATTWriteRequest::decode_varint(uint32_t field_id, ProtoVarInt value) { | ||||
|   switch (field_id) { | ||||
| @@ -2064,12 +2062,12 @@ bool BluetoothGATTNotifyRequest::decode_varint(uint32_t field_id, ProtoVarInt va | ||||
| void BluetoothGATTNotifyDataResponse::encode(ProtoWriteBuffer buffer) const { | ||||
|   buffer.encode_uint64(1, this->address); | ||||
|   buffer.encode_uint32(2, this->handle); | ||||
|   buffer.encode_bytes(3, reinterpret_cast<const uint8_t *>(this->data.data()), this->data.size()); | ||||
|   buffer.encode_bytes(3, this->data_ptr_, this->data_len_); | ||||
| } | ||||
| void BluetoothGATTNotifyDataResponse::calculate_size(uint32_t &total_size) const { | ||||
|   ProtoSize::add_uint64_field(total_size, 1, this->address); | ||||
|   ProtoSize::add_uint32_field(total_size, 1, this->handle); | ||||
|   ProtoSize::add_string_field(total_size, 1, this->data); | ||||
|   ProtoSize::add_bytes_field(total_size, 1, this->data_len_); | ||||
| } | ||||
| void BluetoothConnectionsFreeResponse::encode(ProtoWriteBuffer buffer) const { | ||||
|   buffer.encode_uint32(1, this->free); | ||||
| @@ -2268,11 +2266,11 @@ bool VoiceAssistantAudio::decode_length(uint32_t field_id, ProtoLengthDelimited | ||||
|   return true; | ||||
| } | ||||
| void VoiceAssistantAudio::encode(ProtoWriteBuffer buffer) const { | ||||
|   buffer.encode_bytes(1, reinterpret_cast<const uint8_t *>(this->data.data()), this->data.size()); | ||||
|   buffer.encode_bytes(1, this->data_ptr_, this->data_len_); | ||||
|   buffer.encode_bool(2, this->end); | ||||
| } | ||||
| void VoiceAssistantAudio::calculate_size(uint32_t &total_size) const { | ||||
|   ProtoSize::add_string_field(total_size, 1, this->data); | ||||
|   ProtoSize::add_bytes_field(total_size, 1, this->data_len_); | ||||
|   ProtoSize::add_bool_field(total_size, 1, this->end); | ||||
| } | ||||
| bool VoiceAssistantTimerEventResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { | ||||
|   | ||||
| @@ -965,13 +965,17 @@ class SubscribeLogsRequest : public ProtoDecodableMessage { | ||||
| class SubscribeLogsResponse : public ProtoMessage { | ||||
|  public: | ||||
|   static constexpr uint8_t MESSAGE_TYPE = 29; | ||||
|   static constexpr uint8_t ESTIMATED_SIZE = 13; | ||||
|   static constexpr uint8_t ESTIMATED_SIZE = 11; | ||||
| #ifdef HAS_PROTO_MESSAGE_DUMP | ||||
|   const char *message_name() const override { return "subscribe_logs_response"; } | ||||
| #endif | ||||
|   enums::LogLevel level{}; | ||||
|   std::string message{}; | ||||
|   bool send_failed{false}; | ||||
|   const uint8_t *message_ptr_{nullptr}; | ||||
|   size_t message_len_{0}; | ||||
|   void set_message(const uint8_t *data, size_t len) { | ||||
|     this->message_ptr_ = data; | ||||
|     this->message_len_ = len; | ||||
|   } | ||||
|   void encode(ProtoWriteBuffer buffer) const override; | ||||
|   void calculate_size(uint32_t &total_size) const override; | ||||
| #ifdef HAS_PROTO_MESSAGE_DUMP | ||||
| @@ -1228,7 +1232,12 @@ class CameraImageResponse : public StateResponseProtoMessage { | ||||
| #ifdef HAS_PROTO_MESSAGE_DUMP | ||||
|   const char *message_name() const override { return "camera_image_response"; } | ||||
| #endif | ||||
|   std::string data{}; | ||||
|   const uint8_t *data_ptr_{nullptr}; | ||||
|   size_t data_len_{0}; | ||||
|   void set_data(const uint8_t *data, size_t len) { | ||||
|     this->data_ptr_ = data; | ||||
|     this->data_len_ = len; | ||||
|   } | ||||
|   bool done{false}; | ||||
|   void encode(ProtoWriteBuffer buffer) const override; | ||||
|   void calculate_size(uint32_t &total_size) const override; | ||||
| @@ -1882,7 +1891,12 @@ class BluetoothGATTReadResponse : public ProtoMessage { | ||||
| #endif | ||||
|   uint64_t address{0}; | ||||
|   uint32_t handle{0}; | ||||
|   std::string data{}; | ||||
|   const uint8_t *data_ptr_{nullptr}; | ||||
|   size_t data_len_{0}; | ||||
|   void set_data(const uint8_t *data, size_t len) { | ||||
|     this->data_ptr_ = data; | ||||
|     this->data_len_ = len; | ||||
|   } | ||||
|   void encode(ProtoWriteBuffer buffer) const override; | ||||
|   void calculate_size(uint32_t &total_size) const override; | ||||
| #ifdef HAS_PROTO_MESSAGE_DUMP | ||||
| @@ -1970,7 +1984,12 @@ class BluetoothGATTNotifyDataResponse : public ProtoMessage { | ||||
| #endif | ||||
|   uint64_t address{0}; | ||||
|   uint32_t handle{0}; | ||||
|   std::string data{}; | ||||
|   const uint8_t *data_ptr_{nullptr}; | ||||
|   size_t data_len_{0}; | ||||
|   void set_data(const uint8_t *data, size_t len) { | ||||
|     this->data_ptr_ = data; | ||||
|     this->data_len_ = len; | ||||
|   } | ||||
|   void encode(ProtoWriteBuffer buffer) const override; | ||||
|   void calculate_size(uint32_t &total_size) const override; | ||||
| #ifdef HAS_PROTO_MESSAGE_DUMP | ||||
| @@ -2264,6 +2283,12 @@ class VoiceAssistantAudio : public ProtoDecodableMessage { | ||||
|   const char *message_name() const override { return "voice_assistant_audio"; } | ||||
| #endif | ||||
|   std::string data{}; | ||||
|   const uint8_t *data_ptr_{nullptr}; | ||||
|   size_t data_len_{0}; | ||||
|   void set_data(const uint8_t *data, size_t len) { | ||||
|     this->data_ptr_ = data; | ||||
|     this->data_len_ = len; | ||||
|   } | ||||
|   bool end{false}; | ||||
|   void encode(ProtoWriteBuffer buffer) const override; | ||||
|   void calculate_size(uint32_t &total_size) const override; | ||||
|   | ||||
| @@ -1668,11 +1668,7 @@ void SubscribeLogsResponse::dump_to(std::string &out) const { | ||||
|   out.append("\n"); | ||||
|  | ||||
|   out.append("  message: "); | ||||
|   out.append(format_hex_pretty(this->message)); | ||||
|   out.append("\n"); | ||||
|  | ||||
|   out.append("  send_failed: "); | ||||
|   out.append(YESNO(this->send_failed)); | ||||
|   out.append(format_hex_pretty(this->message_ptr_, this->message_len_)); | ||||
|   out.append("\n"); | ||||
|   out.append("}"); | ||||
| } | ||||
| @@ -1681,7 +1677,7 @@ void NoiseEncryptionSetKeyRequest::dump_to(std::string &out) const { | ||||
|   __attribute__((unused)) char buffer[64]; | ||||
|   out.append("NoiseEncryptionSetKeyRequest {\n"); | ||||
|   out.append("  key: "); | ||||
|   out.append(format_hex_pretty(this->key)); | ||||
|   out.append(format_hex_pretty(reinterpret_cast<const uint8_t *>(this->key.data()), this->key.size())); | ||||
|   out.append("\n"); | ||||
|   out.append("}"); | ||||
| } | ||||
| @@ -1934,7 +1930,7 @@ void CameraImageResponse::dump_to(std::string &out) const { | ||||
|   out.append("\n"); | ||||
|  | ||||
|   out.append("  data: "); | ||||
|   out.append(format_hex_pretty(this->data)); | ||||
|   out.append(format_hex_pretty(this->data_ptr_, this->data_len_)); | ||||
|   out.append("\n"); | ||||
|  | ||||
|   out.append("  done: "); | ||||
| @@ -3143,7 +3139,7 @@ void BluetoothGATTReadResponse::dump_to(std::string &out) const { | ||||
|   out.append("\n"); | ||||
|  | ||||
|   out.append("  data: "); | ||||
|   out.append(format_hex_pretty(this->data)); | ||||
|   out.append(format_hex_pretty(this->data_ptr_, this->data_len_)); | ||||
|   out.append("\n"); | ||||
|   out.append("}"); | ||||
| } | ||||
| @@ -3165,7 +3161,7 @@ void BluetoothGATTWriteRequest::dump_to(std::string &out) const { | ||||
|   out.append("\n"); | ||||
|  | ||||
|   out.append("  data: "); | ||||
|   out.append(format_hex_pretty(this->data)); | ||||
|   out.append(format_hex_pretty(reinterpret_cast<const uint8_t *>(this->data.data()), this->data.size())); | ||||
|   out.append("\n"); | ||||
|   out.append("}"); | ||||
| } | ||||
| @@ -3197,7 +3193,7 @@ void BluetoothGATTWriteDescriptorRequest::dump_to(std::string &out) const { | ||||
|   out.append("\n"); | ||||
|  | ||||
|   out.append("  data: "); | ||||
|   out.append(format_hex_pretty(this->data)); | ||||
|   out.append(format_hex_pretty(reinterpret_cast<const uint8_t *>(this->data.data()), this->data.size())); | ||||
|   out.append("\n"); | ||||
|   out.append("}"); | ||||
| } | ||||
| @@ -3233,7 +3229,7 @@ void BluetoothGATTNotifyDataResponse::dump_to(std::string &out) const { | ||||
|   out.append("\n"); | ||||
|  | ||||
|   out.append("  data: "); | ||||
|   out.append(format_hex_pretty(this->data)); | ||||
|   out.append(format_hex_pretty(this->data_ptr_, this->data_len_)); | ||||
|   out.append("\n"); | ||||
|   out.append("}"); | ||||
| } | ||||
| @@ -3487,7 +3483,11 @@ void VoiceAssistantAudio::dump_to(std::string &out) const { | ||||
|   __attribute__((unused)) char buffer[64]; | ||||
|   out.append("VoiceAssistantAudio {\n"); | ||||
|   out.append("  data: "); | ||||
|   out.append(format_hex_pretty(this->data)); | ||||
|   if (this->data_ptr_ != nullptr) { | ||||
|     out.append(format_hex_pretty(this->data_ptr_, this->data_len_)); | ||||
|   } else { | ||||
|     out.append(format_hex_pretty(reinterpret_cast<const uint8_t *>(this->data.data()), this->data.size())); | ||||
|   } | ||||
|   out.append("\n"); | ||||
|  | ||||
|   out.append("  end: "); | ||||
|   | ||||
| @@ -527,25 +527,6 @@ class ProtoSize { | ||||
|     total_size += field_id_size + 1; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * @brief Calculates and adds the size of a fixed field to the total message size | ||||
|    * | ||||
|    * Fixed fields always take exactly N bytes (4 for fixed32/float, 8 for fixed64/double). | ||||
|    * | ||||
|    * @tparam NumBytes The number of bytes for this fixed field (4 or 8) | ||||
|    * @param is_nonzero Whether the value is non-zero | ||||
|    */ | ||||
|   template<uint32_t NumBytes> | ||||
|   static inline void add_fixed_field(uint32_t &total_size, uint32_t field_id_size, bool is_nonzero) { | ||||
|     // Skip calculation if value is zero | ||||
|     if (!is_nonzero) { | ||||
|       return;  // No need to update total_size | ||||
|     } | ||||
|  | ||||
|     // Fixed fields always take exactly NumBytes | ||||
|     total_size += field_id_size + NumBytes; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * @brief Calculates and adds the size of a float field to the total message size | ||||
|    */ | ||||
| @@ -704,6 +685,19 @@ class ProtoSize { | ||||
|     total_size += field_id_size + varint(str_size) + str_size; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * @brief Calculates and adds the size of a bytes field to the total message size | ||||
|    */ | ||||
|   static inline void add_bytes_field(uint32_t &total_size, uint32_t field_id_size, size_t len) { | ||||
|     // Skip calculation if bytes is empty | ||||
|     if (len == 0) { | ||||
|       return;  // No need to update total_size | ||||
|     } | ||||
|  | ||||
|     // Field ID + length varint + data bytes | ||||
|     total_size += field_id_size + varint(static_cast<uint32_t>(len)) + static_cast<uint32_t>(len); | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * @brief Calculates and adds the size of a nested message field to the total message size | ||||
|    * | ||||
|   | ||||
| @@ -234,9 +234,7 @@ bool BluetoothConnection::gattc_event_handler(esp_gattc_cb_event_t event, esp_ga | ||||
|       api::BluetoothGATTReadResponse resp; | ||||
|       resp.address = this->address_; | ||||
|       resp.handle = param->read.handle; | ||||
|       resp.data.reserve(param->read.value_len); | ||||
|       // Use bulk insert instead of individual push_backs | ||||
|       resp.data.insert(resp.data.end(), param->read.value, param->read.value + param->read.value_len); | ||||
|       resp.set_data(param->read.value, param->read.value_len); | ||||
|       this->proxy_->get_api_connection()->send_message(resp, api::BluetoothGATTReadResponse::MESSAGE_TYPE); | ||||
|       break; | ||||
|     } | ||||
| @@ -287,9 +285,7 @@ bool BluetoothConnection::gattc_event_handler(esp_gattc_cb_event_t event, esp_ga | ||||
|       api::BluetoothGATTNotifyDataResponse resp; | ||||
|       resp.address = this->address_; | ||||
|       resp.handle = param->notify.handle; | ||||
|       resp.data.reserve(param->notify.value_len); | ||||
|       // Use bulk insert instead of individual push_backs | ||||
|       resp.data.insert(resp.data.end(), param->notify.value, param->notify.value + param->notify.value_len); | ||||
|       resp.set_data(param->notify.value, param->notify.value_len); | ||||
|       this->proxy_->get_api_connection()->send_message(resp, api::BluetoothGATTNotifyDataResponse::MESSAGE_TYPE); | ||||
|       break; | ||||
|     } | ||||
|   | ||||
| @@ -273,7 +273,7 @@ void VoiceAssistant::loop() { | ||||
|         size_t read_bytes = this->ring_buffer_->read((void *) this->send_buffer_, SEND_BUFFER_SIZE, 0); | ||||
|         if (this->audio_mode_ == AUDIO_MODE_API) { | ||||
|           api::VoiceAssistantAudio msg; | ||||
|           msg.data.assign((char *) this->send_buffer_, read_bytes); | ||||
|           msg.set_data(this->send_buffer_, read_bytes); | ||||
|           this->api_client_->send_message(msg, api::VoiceAssistantAudio::MESSAGE_TYPE); | ||||
|         } else { | ||||
|           if (!this->udp_socket_running_) { | ||||
|   | ||||
| @@ -113,8 +113,15 @@ def force_str(force: bool) -> str: | ||||
| class TypeInfo(ABC): | ||||
|     """Base class for all type information.""" | ||||
|  | ||||
|     def __init__(self, field: descriptor.FieldDescriptorProto) -> None: | ||||
|     def __init__( | ||||
|         self, | ||||
|         field: descriptor.FieldDescriptorProto, | ||||
|         needs_decode: bool = True, | ||||
|         needs_encode: bool = True, | ||||
|     ) -> None: | ||||
|         self._field = field | ||||
|         self._needs_decode = needs_decode | ||||
|         self._needs_encode = needs_encode | ||||
|  | ||||
|     @property | ||||
|     def default_value(self) -> str: | ||||
| @@ -313,7 +320,11 @@ def validate_field_type(field_type: int, field_name: str = "") -> None: | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def create_field_type_info(field: descriptor.FieldDescriptorProto) -> TypeInfo: | ||||
| def create_field_type_info( | ||||
|     field: descriptor.FieldDescriptorProto, | ||||
|     needs_decode: bool = True, | ||||
|     needs_encode: bool = True, | ||||
| ) -> TypeInfo: | ||||
|     """Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options.""" | ||||
|     if field.label == 3:  # repeated | ||||
|         return RepeatedTypeInfo(field) | ||||
| @@ -325,6 +336,10 @@ def create_field_type_info(field: descriptor.FieldDescriptorProto) -> TypeInfo: | ||||
|     ): | ||||
|         return FixedArrayBytesType(field, fixed_size) | ||||
|  | ||||
|     # Special handling for bytes fields | ||||
|     if field.type == 12: | ||||
|         return BytesType(field, needs_decode, needs_encode) | ||||
|  | ||||
|     validate_field_type(field.type, field.name) | ||||
|     return TYPE_INFO[field.type](field) | ||||
|  | ||||
| @@ -589,20 +604,59 @@ class BytesType(TypeInfo): | ||||
|     default_value = "" | ||||
|     reference_type = "std::string &" | ||||
|     const_reference_type = "const std::string &" | ||||
|     decode_length = "value.as_string()" | ||||
|     encode_func = "encode_bytes" | ||||
|     decode_length = "value.as_string()" | ||||
|     wire_type = WireType.LENGTH_DELIMITED  # Uses wire type 2 | ||||
|  | ||||
|     @property | ||||
|     def public_content(self) -> list[str]: | ||||
|         content: list[str] = [] | ||||
|         # Add std::string storage if message needs decoding | ||||
|         if self._needs_decode: | ||||
|             content.append(f"std::string {self.field_name}{{}};") | ||||
|  | ||||
|         if self._needs_encode: | ||||
|             content.extend( | ||||
|                 [ | ||||
|                     # Add pointer/length fields if message needs encoding | ||||
|                     f"const uint8_t* {self.field_name}_ptr_{{nullptr}};", | ||||
|                     f"size_t {self.field_name}_len_{{0}};", | ||||
|                     # Add setter method if message needs encoding | ||||
|                     f"void set_{self.field_name}(const uint8_t* data, size_t len) {{", | ||||
|                     f"  this->{self.field_name}_ptr_ = data;", | ||||
|                     f"  this->{self.field_name}_len_ = len;", | ||||
|                     "}", | ||||
|                 ] | ||||
|             ) | ||||
|         return content | ||||
|  | ||||
|     @property | ||||
|     def encode_content(self) -> str: | ||||
|         return f"buffer.encode_bytes({self.number}, reinterpret_cast<const uint8_t*>(this->{self.field_name}.data()), this->{self.field_name}.size());" | ||||
|         return f"buffer.encode_bytes({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_);" | ||||
|  | ||||
|     def dump(self, name: str) -> str: | ||||
|         o = f"out.append(format_hex_pretty({name}));" | ||||
|         return o | ||||
|         ptr_dump = f"format_hex_pretty(this->{self.field_name}_ptr_, this->{self.field_name}_len_)" | ||||
|         str_dump = f"format_hex_pretty(reinterpret_cast<const uint8_t*>(this->{self.field_name}.data()), this->{self.field_name}.size())" | ||||
|  | ||||
|         # For SOURCE_CLIENT only, always use std::string | ||||
|         if not self._needs_encode: | ||||
|             return f"out.append({str_dump});" | ||||
|  | ||||
|         # For SOURCE_SERVER, always use pointer/length | ||||
|         if not self._needs_decode: | ||||
|             return f"out.append({ptr_dump});" | ||||
|  | ||||
|         # For SOURCE_BOTH, check if pointer is set (sending) or use string (received) | ||||
|         return ( | ||||
|             f"if (this->{self.field_name}_ptr_ != nullptr) {{\n" | ||||
|             f"    out.append({ptr_dump});\n" | ||||
|             f"  }} else {{\n" | ||||
|             f"    out.append({str_dump});\n" | ||||
|             f"  }}" | ||||
|         ) | ||||
|  | ||||
|     def get_size_calculation(self, name: str, force: bool = False) -> str: | ||||
|         return self._get_simple_size_calculation(name, force, "add_string_field") | ||||
|         return f"ProtoSize::add_bytes_field(total_size, {self.calculate_field_id_size()}, this->{self.field_name}_len_);" | ||||
|  | ||||
|     def get_estimated_size(self) -> int: | ||||
|         return self.calculate_field_id_size() + 8  # field ID + 8 bytes typical bytes | ||||
| @@ -1257,7 +1311,7 @@ def build_message_type( | ||||
|         if field.options.deprecated: | ||||
|             continue | ||||
|  | ||||
|         ti = create_field_type_info(field) | ||||
|         ti = create_field_type_info(field, needs_decode, needs_encode) | ||||
|  | ||||
|         # Skip field declarations for fields that are in the base class | ||||
|         # but include their encode/decode logic | ||||
| @@ -1572,10 +1626,20 @@ def build_base_class( | ||||
|     public_content = [] | ||||
|     protected_content = [] | ||||
|  | ||||
|     # Determine if any message using this base class needs decoding/encoding | ||||
|     needs_decode = any( | ||||
|         message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT) | ||||
|         for msg in messages | ||||
|     ) | ||||
|     needs_encode = any( | ||||
|         message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_SERVER) | ||||
|         for msg in messages | ||||
|     ) | ||||
|  | ||||
|     # For base classes, we only declare the fields but don't handle encode/decode | ||||
|     # The derived classes will handle encoding/decoding with their specific field numbers | ||||
|     for field in common_fields: | ||||
|         ti = create_field_type_info(field) | ||||
|         ti = create_field_type_info(field, needs_decode, needs_encode) | ||||
|  | ||||
|         # Get field_ifdef if it's consistent across all messages | ||||
|         field_ifdef = get_common_field_ifdef(field.name, messages) | ||||
| @@ -1586,12 +1650,6 @@ def build_base_class( | ||||
|         if ti.public_content: | ||||
|             public_content.extend(wrap_with_ifdef(ti.public_content, field_ifdef)) | ||||
|  | ||||
|     # Determine if any message using this base class needs decoding | ||||
|     needs_decode = any( | ||||
|         message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT) | ||||
|         for msg in messages | ||||
|     ) | ||||
|  | ||||
|     # Build header | ||||
|     parent_class = "ProtoDecodableMessage" if needs_decode else "ProtoMessage" | ||||
|     out = f"class {base_class_name} : public {parent_class} {{\n" | ||||
|   | ||||
		Reference in New Issue
	
	Block a user